diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index dd95e406f2a8e..009ed64775844 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -108,6 +108,14 @@ class SparkEnv (
pythonWorkers.get(key).foreach(_.stopWorker(worker))
}
}
+
+ private[spark]
+ def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+ synchronized {
+ val key = (pythonExec, envVars)
+ pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+ }
+ }
}
object SparkEnv extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index ae8010300a500..d5002fa02992b 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
extends RDD[Array[Byte]](parent) {
val bufferSize = conf.getInt("spark.buffer.size", 65536)
+ val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
override def getPartitions = parent.partitions
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
val localdir = env.blockManager.diskBlockManager.localDirs.map(
f => f.getPath()).mkString(",")
envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+ if (reuse_worker) {
+ envVars += ("SPARK_REUSE_WORKER" -> "1")
+ }
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
+ var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
-
- // Cleanup the worker socket. This will also cause the Python worker to exit.
- try {
- worker.close()
- } catch {
- case e: Exception => logWarning("Failed to close worker socket", e)
+ if (reuse_worker && complete_cleanly) {
+ env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+ } else {
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
}
}
@@ -115,6 +124,10 @@ private[spark] class PythonRDD(
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
+ val memoryBytesSpilled = stream.readLong()
+ val diskBytesSpilled = stream.readLong()
+ context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+ context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
@@ -133,6 +146,7 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
+ complete_cleanly = true
null
}
} catch {
@@ -195,11 +209,26 @@ private[spark] class PythonRDD(
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
- dataOut.writeInt(broadcastVars.length)
+ val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+ val newBids = broadcastVars.map(_.id).toSet
+ // number of broadcast entries that changed (newly added or removed)
+ val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+ dataOut.writeInt(cnt)
+ for (bid <- oldBids) {
+ if (!newBids.contains(bid)) {
+ // remove the broadcast from worker
+ dataOut.writeLong(- bid - 1) // negative id tells the worker to remove this broadcast (real ids are >= 0)
+ oldBids.remove(bid)
+ }
+ }
for (broadcast <- broadcastVars) {
- dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ if (!oldBids.contains(broadcast.id)) {
+ // send new broadcast
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ oldBids.add(broadcast.id)
+ }
}
dataOut.flush()
// Serialized command:
@@ -207,17 +236,18 @@ private[spark] class PythonRDD(
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+ worker.shutdownOutput()
case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
- } finally {
- Try(worker.shutdownOutput()) // kill Python worker process
+ worker.shutdownOutput()
}
}
}
@@ -278,6 +308,14 @@ private object SpecialLengths {
private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")
+ // remember the broadcasts sent to each worker
+ private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+ private def getWorkerBroadcasts(worker: Socket) = {
+ synchronized {
+ workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+ }
+ }
+
/**
* Adapter for calling SparkContext#runJob from Python.
*
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 4c4796f6c59ba..71bdf0fe1b917 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
- var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+ val idleWorkers = new mutable.Queue[Socket]()
+ var lastActivity = 0L
+ new MonitorThread().start()
var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
def create(): Socket = {
if (useDaemon) {
+ synchronized {
+ if (idleWorkers.size > 0) {
+ return idleWorkers.dequeue()
+ }
+ }
createThroughDaemon()
} else {
createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
+ /**
+ * Monitor all the idle workers and kill them once the idle timeout expires.
+ */
+ private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+ setDaemon(true)
+
+ override def run() {
+ while (true) {
+ synchronized {
+ if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+ cleanupIdleWorkers()
+ lastActivity = System.currentTimeMillis()
+ }
+ }
+ Thread.sleep(10000)
+ }
+ }
+ }
+
+ private def cleanupIdleWorkers() {
+ while (idleWorkers.length > 0) {
+ val worker = idleWorkers.dequeue()
+ try {
+ // the idle worker will exit once its socket is closed
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
+
private def stopDaemon() {
synchronized {
if (useDaemon) {
+ cleanupIdleWorkers()
+
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
def stopWorker(worker: Socket) {
- if (useDaemon) {
- if (daemon != null) {
- daemonWorkers.get(worker).foreach { pid =>
- // tell daemon to kill worker by pid
- val output = new DataOutputStream(daemon.getOutputStream)
- output.writeInt(pid)
- output.flush()
- daemon.getOutputStream.flush()
+ synchronized {
+ if (useDaemon) {
+ if (daemon != null) {
+ daemonWorkers.get(worker).foreach { pid =>
+ // tell daemon to kill worker by pid
+ val output = new DataOutputStream(daemon.getOutputStream)
+ output.writeInt(pid)
+ output.flush()
+ daemon.getOutputStream.flush()
+ }
}
+ } else {
+ simpleWorkers.get(worker).foreach(_.destroy())
}
- } else {
- simpleWorkers.get(worker).foreach(_.destroy())
}
worker.close()
}
+
+ def releaseWorker(worker: Socket) {
+ if (useDaemon) {
+ synchronized {
+ lastActivity = System.currentTimeMillis()
+ idleWorkers.enqueue(worker)
+ }
+ } else {
+ // Cleanup the worker socket. This will also cause the Python worker to exit.
+ try {
+ worker.close()
+ } catch {
+ case e: Exception =>
+ logWarning("Failed to close worker socket", e)
+ }
+ }
+ }
}
private object PythonWorkerFactory {
val PROCESS_WAIT_TIMEOUT_MS = 10000
+ val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute
}
diff --git a/docs/configuration.md b/docs/configuration.md
index 36178efb97103..af16489a44281 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
used during aggregation goes above this amount, it will spill the data into disks.
+<tr>
+  <td><code>spark.python.worker.reuse</code></td>
+  <td>true</td>
+  <td>
+    Reuse Python worker or not. If yes, it will use a fixed number of Python workers and
+    does not need to fork() a Python process for every task. This is very useful when there
+    is a large broadcast, since the broadcast does not need to be transferred from the JVM
+    to the Python workers for every task.
+  </td>
+</tr>
  <td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
  <td>(none)</td>
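
The sketch below is not part of the patch; it only illustrates how the `spark.python.worker.reuse` property documented above can be set. The application name is hypothetical, and for a PySpark job the property is more typically passed via `spark-submit --conf spark.python.worker.reuse=false`.

```scala
import org.apache.spark.SparkConf

// Illustrative configuration only; the property defaults to true (see the row above).
val conf = new SparkConf()
  .setAppName("WorkerReuseDemo")              // hypothetical application name
  .set("spark.python.worker.reuse", "false")  // disable reuse to fork a fresh worker per task
```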
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 3159d52787d5a..8d41fdec699e9 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -918,7 +918,6 @@ options.
## Migration Guide for Shark User
### Scheduling
-s
To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session,
users can set the `spark.sql.thriftserver.scheduler.pool` variable:
@@ -1110,7 +1109,7 @@ evaluated by the SQL execution engine. A full list of the functions supported c
The range of numbers is from `-9223372036854775808` to `9223372036854775807`.
- `FloatType`: Represents 4-byte single-precision floating point numbers.
- `DoubleType`: Represents 8-byte double-precision floating point numbers.
- - `DecimalType`:
+ - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale.
* String type
- `StringType`: Represents character string values.
* Binary type
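
As a small aside on the `DecimalType` description added above, the sketch below (not part of the patch) shows the unscaled-value/scale decomposition of a `java.math.BigDecimal`, which is the representation the doc refers to.

```scala
import java.math.BigDecimal

val d = new BigDecimal("123.45")
d.unscaledValue()  // 12345, an arbitrary-precision integer
d.scale()          // 2, a 32-bit integer; the value is 12345 * 10^(-2)
```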
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 15445abf67147..64d6202acb27d 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -23,6 +23,7 @@
import sys
import traceback
import time
+import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -46,17 +47,6 @@ def worker(sock):
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
- # Blocks until the socket is closed by draining the input stream
- # until it raises an exception or returns EOF.
- def waitSocketClose(sock):
- try:
- while True:
- # Empty string is returned upon EOF (and only then).
- if sock.recv(4096) == '':
- return
- except:
- pass
-
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
# otherwise writes also cause a seek that makes us miss data on the read side.
@@ -64,17 +54,13 @@ def waitSocketClose(sock):
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
exit_code = 0
try:
- # Acknowledge that the fork was successful
- write_int(os.getpid(), outfile)
- outfile.flush()
worker_main(infile, outfile)
except SystemExit as exc:
- exit_code = exc.code
+ exit_code = compute_real_exit_code(exc.code)
finally:
outfile.flush()
- # The Scala side will close the socket upon task completion.
- waitSocketClose(sock)
- os._exit(compute_real_exit_code(exit_code))
+ if exit_code:
+ os._exit(exit_code)
# Cleanup zombie children
@@ -111,6 +97,8 @@ def handle_sigterm(*args):
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
+ reuse = os.environ.get("SPARK_REUSE_WORKER")
+
# Initialization complete
try:
while True:
@@ -163,7 +151,19 @@ def handle_sigterm(*args):
# in child process
listen_sock.close()
try:
- worker(sock)
+ # Acknowledge that the fork was successful
+ outfile = sock.makefile("w")
+ write_int(os.getpid(), outfile)
+ outfile.flush()
+ outfile.close()
+ while True:
+ worker(sock)
+ if not reuse:
+ # wait for the JVM side to close the socket
+ while sock.recv(1024):
+ pass
+ break
+ gc.collect()
except:
traceback.print_exc()
os._exit(1)
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index bb60d3d0c8463..68f6033616726 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -21,7 +21,7 @@
from numpy import ndarray, float64, int64, int32, array_equal, array
from pyspark import SparkContext, RDD
from pyspark.mllib.linalg import SparseVector
-from pyspark.serializers import Serializer
+from pyspark.serializers import FramedSerializer
"""
@@ -451,18 +451,16 @@ def _serialize_rating(r):
return ba
-class RatingDeserializer(Serializer):
+class RatingDeserializer(FramedSerializer):
- def loads(self, stream):
- length = struct.unpack("!i", stream.read(4))[0]
- ba = stream.read(length)
- res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4)
+ def loads(self, string):
+ res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4)
return int(res[0]), int(res[1]), res[2]
def load_stream(self, stream):
while True:
try:
- yield self.loads(stream)
+ yield self._read_with_length(stream)
except struct.error:
return
except EOFError:
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a5f9341e819a9..ec3c6f055441d 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream):
def _read_with_length(self, stream):
length = read_int(stream)
+ if length == SpecialLengths.END_OF_DATA_SECTION:
+ raise EOFError
obj = stream.read(length)
if obj == "":
raise EOFError
@@ -438,6 +440,8 @@ def __init__(self, use_unicode=False):
def loads(self, stream):
length = read_int(stream)
+ if length == SpecialLengths.END_OF_DATA_SECTION:
+ raise EOFError
s = stream.read(length)
return s.decode("utf-8") if self.use_unicode else s
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 49829f5280a5f..ce597cbe91e15 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
class Aggregator(object):
"""
@@ -313,10 +318,12 @@ def _spill(self):
It will dump the data in batch for better performance.
"""
+ global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)
+ used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
self.serializer.dump_stream([(k, v)], streams[h])
for s in streams:
+ DiskBytesSpilled += s.tell()
s.close()
self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
+ DiskBytesSpilled += os.path.getsize(p)
self.spills += 1
gc.collect() # release the memory as much as possible
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
def iteritems(self):
""" Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
- self._spilled_bytes = 0
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
+ global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
if len(chunk) < batch:
break
- if get_used_memory() > self.memory_limit:
+ used_memory = get_used_memory()
+ if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
- self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
+ gc.collect()
+ MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+ DiskBytesSpilled += os.path.getsize(path)
elif not chunks:
batch = min(batch * 2, 10000)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b687d695b01c4..f3309a20fcffb 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -46,6 +46,7 @@
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
+from pyspark import shuffle
_have_scipy = False
_have_numpy = False
@@ -138,17 +139,17 @@ def test_external_sort(self):
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
- self.assertGreater(sorter._spilled_bytes, 0)
- last = sorter._spilled_bytes
+ self.assertGreater(shuffle.DiskBytesSpilled, 0)
+ last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
- self.assertGreater(sorter._spilled_bytes, last)
- last = sorter._spilled_bytes
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+ last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
- self.assertGreater(sorter._spilled_bytes, last)
- last = sorter._spilled_bytes
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
+ last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
- self.assertGreater(sorter._spilled_bytes, last)
+ self.assertGreater(shuffle.DiskBytesSpilled, last)
def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
@@ -1222,11 +1223,46 @@ def run():
except OSError:
self.fail("daemon had been killed")
+ # run a normal job
+ rdd = self.sc.parallelize(range(100), 1)
+ self.assertEqual(100, rdd.map(str).count())
+
def test_fd_leak(self):
N = 1100 # fd limit is 1024 by default
rdd = self.sc.parallelize(range(N), N)
self.assertEquals(N, rdd.count())
+ def test_after_exception(self):
+ def raise_exception(_):
+ raise Exception()
+ rdd = self.sc.parallelize(range(100), 1)
+ self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+ self.assertEqual(100, rdd.map(str).count())
+
+ def test_after_jvm_exception(self):
+ tempFile = tempfile.NamedTemporaryFile(delete=False)
+ tempFile.write("Hello World!")
+ tempFile.close()
+ data = self.sc.textFile(tempFile.name, 1)
+ filtered_data = data.filter(lambda x: True)
+ self.assertEqual(1, filtered_data.count())
+ os.unlink(tempFile.name)
+ self.assertRaises(Exception, lambda: filtered_data.count())
+
+ rdd = self.sc.parallelize(range(100), 1)
+ self.assertEqual(100, rdd.map(str).count())
+
+ def test_accumulator_when_reuse_worker(self):
+ from pyspark.accumulators import INT_ACCUMULATOR_PARAM
+ acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+ self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x))
+ self.assertEqual(sum(range(100)), acc1.value)
+
+ acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
+ self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x))
+ self.assertEqual(sum(range(100)), acc2.value)
+ self.assertEqual(sum(range(100)), acc1.value)
+
class TestSparkSubmit(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 6805063e06798..252176ac65fec 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,16 +23,14 @@
import time
import socket
import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
+
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer
-
+from pyspark import shuffle
pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return
+ # initialize global state
+ shuffle.MemoryBytesSpilled = 0
+ shuffle.DiskBytesSpilled = 0
+ _accumulatorRegistry.clear()
+
# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -69,9 +72,14 @@ def main(infile, outfile):
ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = ser._read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, value)
+ if bid >= 0:
+ value = ser._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
+ else:
+ bid = - bid - 1
+ _broadcastRegistry.pop(bid)
+ _accumulatorRegistry.clear()
command = pickleSer._read_with_length(infile)
(func, deserializer, serializer) = command
init_time = time.time()
@@ -92,6 +100,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
+ write_long(shuffle.MemoryBytesSpilled, outfile)
+ write_long(shuffle.DiskBytesSpilled, outfile)
+
# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index ca69531c69a77..068cb49ef6d34 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -151,7 +151,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)} |
UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
)
- | insert | cache
+ | insert | cache | unCache
)
protected lazy val select: Parser[LogicalPlan] =
@@ -183,9 +183,17 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
}
protected lazy val cache: Parser[LogicalPlan] =
- (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ {
- case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache)
+ CACHE ~ TABLE ~> ident ~ opt(AS ~ select) <~ opt(";") ^^ {
+ case tableName ~ None =>
+ CacheCommand(tableName, true)
+ case tableName ~ Some(as ~ plan) =>
+ CacheTableAsSelectCommand(tableName, plan)
}
+
+ protected lazy val unCache: Parser[LogicalPlan] =
+ UNCACHE ~ TABLE ~> ident <~ opt(";") ^^ {
+ case tableName => CacheCommand(tableName, false)
+ }
protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
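
For reference, a minimal sketch (not part of the patch) of the statements the extended parser accepts; the `sqlContext` value and the table names are assumptions for illustration.

```scala
// Assumes a SQLContext named sqlContext with a registered table "words".
sqlContext.sql("CACHE TABLE words")                               // parsed as CacheCommand(words, true)
sqlContext.sql("CACHE TABLE cached_words AS SELECT * FROM words") // parsed as CacheTableAsSelectCommand
sqlContext.sql("UNCACHE TABLE words")                             // parsed as CacheCommand(words, false)
```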
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
index 088f11ee4aa53..9cbab3d5d0d0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
@@ -171,7 +171,7 @@ final class MutableByte extends MutableValue {
}
final class MutableAny extends MutableValue {
- var value: Any = 0
+ var value: Any = _
def boxed = if (isNull) null else value
def update(v: Any) = value = {
isNull = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
index a01809c1fc5e2..8366639fa0e8b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -75,3 +75,8 @@ case class DescribeCommand(
AttributeReference("data_type", StringType, nullable = false)(),
AttributeReference("comment", StringType, nullable = false)())
}
+
+/**
+ * Returned for the "CACHE TABLE tableName AS SELECT .." command.
+ */
+case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan) extends Command
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
index 42a5a9a84f362..c9faf0852142a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala
@@ -50,11 +50,13 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](
def hasNext = buffer.hasRemaining
- def extractTo(row: MutableRow, ordinal: Int) {
- columnType.setField(row, ordinal, extractSingle(buffer))
+ def extractTo(row: MutableRow, ordinal: Int): Unit = {
+ extractSingle(row, ordinal)
}
- def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer)
+ def extractSingle(row: MutableRow, ordinal: Int): Unit = {
+ columnType.extract(buffer, row, ordinal)
+ }
protected def underlyingBuffer = buffer
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
index b3ec5ded22422..2e61a981375aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala
@@ -68,10 +68,9 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType](
buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId)
}
- override def appendFrom(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
- buffer = ensureFreeSpace(buffer, columnType.actualSize(field))
- columnType.append(field, buffer)
+ override def appendFrom(row: Row, ordinal: Int): Unit = {
+ buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
+ columnType.append(row, ordinal, buffer)
}
override def build() = {
@@ -142,16 +141,16 @@ private[sql] object ColumnBuilder {
useCompression: Boolean = false): ColumnBuilder = {
val builder = (typeId match {
- case INT.typeId => new IntColumnBuilder
- case LONG.typeId => new LongColumnBuilder
- case FLOAT.typeId => new FloatColumnBuilder
- case DOUBLE.typeId => new DoubleColumnBuilder
- case BOOLEAN.typeId => new BooleanColumnBuilder
- case BYTE.typeId => new ByteColumnBuilder
- case SHORT.typeId => new ShortColumnBuilder
- case STRING.typeId => new StringColumnBuilder
- case BINARY.typeId => new BinaryColumnBuilder
- case GENERIC.typeId => new GenericColumnBuilder
+ case INT.typeId => new IntColumnBuilder
+ case LONG.typeId => new LongColumnBuilder
+ case FLOAT.typeId => new FloatColumnBuilder
+ case DOUBLE.typeId => new DoubleColumnBuilder
+ case BOOLEAN.typeId => new BooleanColumnBuilder
+ case BYTE.typeId => new ByteColumnBuilder
+ case SHORT.typeId => new ShortColumnBuilder
+ case STRING.typeId => new StringColumnBuilder
+ case BINARY.typeId => new BinaryColumnBuilder
+ case GENERIC.typeId => new GenericColumnBuilder
case TIMESTAMP.typeId => new TimestampColumnBuilder
}).asInstanceOf[ColumnBuilder]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index fc343ccb995c2..203a714e03c97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -69,7 +69,7 @@ private[sql] class ByteColumnStats extends ColumnStats {
var lower = Byte.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getByte(ordinal)
if (value > upper) upper = value
@@ -87,7 +87,7 @@ private[sql] class ShortColumnStats extends ColumnStats {
var lower = Short.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getShort(ordinal)
if (value > upper) upper = value
@@ -105,7 +105,7 @@ private[sql] class LongColumnStats extends ColumnStats {
var lower = Long.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getLong(ordinal)
if (value > upper) upper = value
@@ -123,7 +123,7 @@ private[sql] class DoubleColumnStats extends ColumnStats {
var lower = Double.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getDouble(ordinal)
if (value > upper) upper = value
@@ -141,7 +141,7 @@ private[sql] class FloatColumnStats extends ColumnStats {
var lower = Float.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getFloat(ordinal)
if (value > upper) upper = value
@@ -159,7 +159,7 @@ private[sql] class IntColumnStats extends ColumnStats {
var lower = Int.MaxValue
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getInt(ordinal)
if (value > upper) upper = value
@@ -177,7 +177,7 @@ private[sql] class StringColumnStats extends ColumnStats {
var lower: String = null
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row.getString(ordinal)
if (upper == null || value.compareTo(upper) > 0) upper = value
@@ -195,7 +195,7 @@ private[sql] class TimestampColumnStats extends ColumnStats {
var lower: Timestamp = null
var nullCount = 0
- override def gatherStats(row: Row, ordinal: Int) {
+ override def gatherStats(row: Row, ordinal: Int): Unit = {
if (!row.isNullAt(ordinal)) {
val value = row(ordinal).asInstanceOf[Timestamp]
if (upper == null || value.compareTo(upper) > 0) upper = value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 9a61600115872..198b5756676aa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -18,11 +18,10 @@
package org.apache.spark.sql.columnar
import java.nio.ByteBuffer
+import java.sql.Timestamp
import scala.reflect.runtime.universe.TypeTag
-import java.sql.Timestamp
-
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types._
@@ -46,16 +45,33 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
*/
def extract(buffer: ByteBuffer): JvmType
+ /**
+ * Extracts a value out of the buffer at the buffer's current position and stores it in
+ * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever
+ * possible.
+ */
+ def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ setField(row, ordinal, extract(buffer))
+ }
+
/**
* Appends the given value v of type T into the given ByteBuffer.
*/
- def append(v: JvmType, buffer: ByteBuffer)
+ def append(v: JvmType, buffer: ByteBuffer): Unit
+
+ /**
+ * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this
+ * method to avoid boxing/unboxing costs whenever possible.
+ */
+ def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ append(getField(row, ordinal), buffer)
+ }
/**
- * Returns the size of the value. This is used to calculate the size of variable length types
- * such as byte arrays and strings.
+ * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
+ * length types such as byte arrays and strings.
*/
- def actualSize(v: JvmType): Int = defaultSize
+ def actualSize(row: Row, ordinal: Int): Int = defaultSize
/**
* Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
@@ -67,7 +83,15 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
* costs whenever possible.
*/
- def setField(row: MutableRow, ordinal: Int, value: JvmType)
+ def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit
+
+ /**
+ * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
+ * boxing/unboxing costs whenever possible.
+ */
+ def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to(toOrdinal) = from(fromOrdinal)
+ }
/**
* Creates a duplicated copy of the value.
@@ -90,119 +114,205 @@ private[sql] abstract class NativeColumnType[T <: NativeType](
}
private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) {
- def append(v: Int, buffer: ByteBuffer) {
+ def append(v: Int, buffer: ByteBuffer): Unit = {
buffer.putInt(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putInt(row.getInt(ordinal))
+ }
+
def extract(buffer: ByteBuffer) = {
buffer.getInt()
}
- override def setField(row: MutableRow, ordinal: Int, value: Int) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setInt(ordinal, buffer.getInt())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
row.setInt(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getInt(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setInt(toOrdinal, from.getInt(fromOrdinal))
+ }
}
private[sql] object LONG extends NativeColumnType(LongType, 1, 8) {
- override def append(v: Long, buffer: ByteBuffer) {
+ override def append(v: Long, buffer: ByteBuffer): Unit = {
buffer.putLong(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putLong(row.getLong(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getLong()
}
- override def setField(row: MutableRow, ordinal: Int, value: Long) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setLong(ordinal, buffer.getLong())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
row.setLong(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getLong(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setLong(toOrdinal, from.getLong(fromOrdinal))
+ }
}
private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) {
- override def append(v: Float, buffer: ByteBuffer) {
+ override def append(v: Float, buffer: ByteBuffer): Unit = {
buffer.putFloat(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putFloat(row.getFloat(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getFloat()
}
- override def setField(row: MutableRow, ordinal: Int, value: Float) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setFloat(ordinal, buffer.getFloat())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = {
row.setFloat(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setFloat(toOrdinal, from.getFloat(fromOrdinal))
+ }
}
private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) {
- override def append(v: Double, buffer: ByteBuffer) {
+ override def append(v: Double, buffer: ByteBuffer): Unit = {
buffer.putDouble(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putDouble(row.getDouble(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getDouble()
}
- override def setField(row: MutableRow, ordinal: Int, value: Double) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setDouble(ordinal, buffer.getDouble())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = {
row.setDouble(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setDouble(toOrdinal, from.getDouble(fromOrdinal))
+ }
}
private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) {
- override def append(v: Boolean, buffer: ByteBuffer) {
- buffer.put(if (v) 1.toByte else 0.toByte)
+ override def append(v: Boolean, buffer: ByteBuffer): Unit = {
+ buffer.put(if (v) 1: Byte else 0: Byte)
+ }
+
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte)
}
override def extract(buffer: ByteBuffer) = buffer.get() == 1
- override def setField(row: MutableRow, ordinal: Int, value: Boolean) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setBoolean(ordinal, buffer.get() == 1)
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = {
row.setBoolean(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal))
+ }
}
private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) {
- override def append(v: Byte, buffer: ByteBuffer) {
+ override def append(v: Byte, buffer: ByteBuffer): Unit = {
buffer.put(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.put(row.getByte(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.get()
}
- override def setField(row: MutableRow, ordinal: Int, value: Byte) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setByte(ordinal, buffer.get())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = {
row.setByte(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getByte(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setByte(toOrdinal, from.getByte(fromOrdinal))
+ }
}
private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) {
- override def append(v: Short, buffer: ByteBuffer) {
+ override def append(v: Short, buffer: ByteBuffer): Unit = {
buffer.putShort(v)
}
+ override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = {
+ buffer.putShort(row.getShort(ordinal))
+ }
+
override def extract(buffer: ByteBuffer) = {
buffer.getShort()
}
- override def setField(row: MutableRow, ordinal: Int, value: Short) {
+ override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+ row.setShort(ordinal, buffer.getShort())
+ }
+
+ override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = {
row.setShort(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getShort(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setShort(toOrdinal, from.getShort(fromOrdinal))
+ }
}
private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
- override def actualSize(v: String): Int = v.getBytes("utf-8").length + 4
+ override def actualSize(row: Row, ordinal: Int): Int = {
+ row.getString(ordinal).getBytes("utf-8").length + 4
+ }
- override def append(v: String, buffer: ByteBuffer) {
+ override def append(v: String, buffer: ByteBuffer): Unit = {
val stringBytes = v.getBytes("utf-8")
buffer.putInt(stringBytes.length).put(stringBytes, 0, stringBytes.length)
}
@@ -214,11 +324,15 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) {
new String(stringBytes, "utf-8")
}
- override def setField(row: MutableRow, ordinal: Int, value: String) {
+ override def setField(row: MutableRow, ordinal: Int, value: String): Unit = {
row.setString(ordinal, value)
}
override def getField(row: Row, ordinal: Int) = row.getString(ordinal)
+
+ override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+ to.setString(toOrdinal, from.getString(fromOrdinal))
+ }
}
private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
@@ -228,7 +342,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
timestamp
}
- override def append(v: Timestamp, buffer: ByteBuffer) {
+ override def append(v: Timestamp, buffer: ByteBuffer): Unit = {
buffer.putLong(v.getTime).putInt(v.getNanos)
}
@@ -236,7 +350,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 8, 12) {
row(ordinal).asInstanceOf[Timestamp]
}
- override def setField(row: MutableRow, ordinal: Int, value: Timestamp) {
+ override def setField(row: MutableRow, ordinal: Int, value: Timestamp): Unit = {
row(ordinal) = value
}
}
@@ -246,9 +360,11 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
defaultSize: Int)
extends ColumnType[T, Array[Byte]](typeId, defaultSize) {
- override def actualSize(v: Array[Byte]) = v.length + 4
+ override def actualSize(row: Row, ordinal: Int) = {
+ getField(row, ordinal).length + 4
+ }
- override def append(v: Array[Byte], buffer: ByteBuffer) {
+ override def append(v: Array[Byte], buffer: ByteBuffer): Unit = {
buffer.putInt(v.length).put(v, 0, v.length)
}
@@ -261,7 +377,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
}
private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) {
- override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = value
}
@@ -272,7 +388,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](9, 16) {
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
private[sql] object GENERIC extends ByteArrayColumnType[DataType](10, 16) {
- override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) {
+ override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index 6eab2f23c18e1..8a3612cdf19be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -52,7 +52,7 @@ private[sql] case class InMemoryRelation(
// As in Spark, the actual work of caching is lazy.
if (_cachedColumnBuffers == null) {
val output = child.output
- val cached = child.execute().mapPartitions { baseIterator =>
+ val cached = child.execute().mapPartitions { rowIterator =>
new Iterator[CachedBatch] {
def next() = {
val columnBuilders = output.map { attribute =>
@@ -61,11 +61,9 @@ private[sql] case class InMemoryRelation(
ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression)
}.toArray
- var row: Row = null
var rowCount = 0
-
- while (baseIterator.hasNext && rowCount < batchSize) {
- row = baseIterator.next()
+ while (rowIterator.hasNext && rowCount < batchSize) {
+ val row = rowIterator.next()
var i = 0
while (i < row.length) {
columnBuilders(i).appendFrom(row, i)
@@ -80,7 +78,7 @@ private[sql] case class InMemoryRelation(
CachedBatch(columnBuilders.map(_.build()), stats)
}
- def hasNext = baseIterator.hasNext
+ def hasNext = rowIterator.hasNext
}
}.cache()
@@ -182,6 +180,7 @@ private[sql] case class InMemoryColumnarTableScan(
}
}
+ // Accumulators used for testing purposes
val readPartitions = sparkContext.accumulator(0)
val readBatches = sparkContext.accumulator(0)
@@ -191,40 +190,36 @@ private[sql] case class InMemoryColumnarTableScan(
readPartitions.setValue(0)
readBatches.setValue(0)
- relation.cachedColumnBuffers.mapPartitions { iterator =>
+ relation.cachedColumnBuffers.mapPartitions { cachedBatchIterator =>
val partitionFilter = newPredicate(
partitionFilters.reduceOption(And).getOrElse(Literal(true)),
relation.partitionStatistics.schema)
- // Find the ordinals of the requested columns. If none are requested, use the first.
- val requestedColumns = if (attributes.isEmpty) {
- Seq(0)
+ // Find the ordinals and data types of the requested columns. If none are requested, use the
+ // narrowest (the field with minimum default element size).
+ val (requestedColumnIndices, requestedColumnDataTypes) = if (attributes.isEmpty) {
+ val (narrowestOrdinal, narrowestDataType) =
+ relation.output.zipWithIndex.map { case (a, ordinal) =>
+ ordinal -> a.dataType
+ } minBy { case (_, dataType) =>
+ ColumnType(dataType).defaultSize
+ }
+ Seq(narrowestOrdinal) -> Seq(narrowestDataType)
} else {
- attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId))
+ attributes.map { a =>
+ relation.output.indexWhere(_.exprId == a.exprId) -> a.dataType
+ }.unzip
}
- val rows = iterator
- // Skip pruned batches
- .filter { cachedBatch =>
- if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) {
- def statsString = relation.partitionStatistics.schema
- .zip(cachedBatch.stats)
- .map { case (a, s) => s"${a.name}: $s" }
- .mkString(", ")
- logInfo(s"Skipping partition based on stats $statsString")
- false
- } else {
- readBatches += 1
- true
- }
- }
- // Build column accessors
- .map { cachedBatch =>
- requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
- }
- // Extract rows via column accessors
- .flatMap { columnAccessors =>
- val nextRow = new GenericMutableRow(columnAccessors.length)
+ val nextRow = new SpecificMutableRow(requestedColumnDataTypes)
+
+ def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
+ val rows = cacheBatches.flatMap { cachedBatch =>
+ // Build column accessors
+ val columnAccessors =
+ requestedColumnIndices.map(cachedBatch.buffers(_)).map(ColumnAccessor(_))
+
+ // Extract rows via column accessors
new Iterator[Row] {
override def next() = {
var i = 0
@@ -235,15 +230,38 @@ private[sql] case class InMemoryColumnarTableScan(
nextRow
}
- override def hasNext = columnAccessors.head.hasNext
+ override def hasNext = columnAccessors(0).hasNext
}
}
- if (rows.hasNext) {
- readPartitions += 1
+ if (rows.hasNext) {
+ readPartitions += 1
+ }
+
+ rows
}
- rows
+ // Do partition batch pruning if enabled
+ val cachedBatchesToScan =
+ if (inMemoryPartitionPruningEnabled) {
+ cachedBatchIterator.filter { cachedBatch =>
+ if (!partitionFilter(cachedBatch.stats)) {
+ def statsString = relation.partitionStatistics.schema
+ .zip(cachedBatch.stats)
+ .map { case (a, s) => s"${a.name}: $s" }
+ .mkString(", ")
+ logInfo(s"Skipping partition based on stats $statsString")
+ false
+ } else {
+ readBatches += 1
+ true
+ }
+ }
+ } else {
+ cachedBatchIterator
+ }
+
+ cachedBatchesToRows(cachedBatchesToScan)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
index b7f8826861a2c..965782a40031b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala
@@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
private var nextNullIndex: Int = _
private var pos: Int = 0
- abstract override protected def initialize() {
+ abstract override protected def initialize(): Unit = {
nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder())
nullCount = nullsBuffer.getInt()
nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1
@@ -39,7 +39,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor {
super.initialize()
}
- abstract override def extractTo(row: MutableRow, ordinal: Int) {
+ abstract override def extractTo(row: MutableRow, ordinal: Int): Unit = {
if (pos == nextNullIndex) {
seenNulls += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
index a72970eef7aa4..f1f494ac26d0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala
@@ -40,7 +40,11 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
protected var nullCount: Int = _
private var pos: Int = _
- abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
+ abstract override def initialize(
+ initialSize: Int,
+ columnName: String,
+ useCompression: Boolean): Unit = {
+
nulls = ByteBuffer.allocate(1024)
nulls.order(ByteOrder.nativeOrder())
pos = 0
@@ -48,7 +52,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder {
super.initialize(initialSize, columnName, useCompression)
}
- abstract override def appendFrom(row: Row, ordinal: Int) {
+ abstract override def appendFrom(row: Row, ordinal: Int): Unit = {
columnStats.gatherStats(row, ordinal)
if (row.isNullAt(ordinal)) {
nulls = ColumnBuilder.ensureFreeSpace(nulls, 4)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
index b4120a3d4368b..27ac5f4dbdbbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.columnar.compression
-import java.nio.ByteBuffer
-
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor}
@@ -34,5 +33,7 @@ private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAcc
abstract override def hasNext = super.hasNext || decoder.hasNext
- override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next()
+ override def extractSingle(row: MutableRow, ordinal: Int): Unit = {
+ decoder.next(row, ordinal)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
index a5826bb033e41..628d9cec41d6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala
@@ -48,12 +48,16 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
var compressionEncoders: Seq[Encoder[T]] = _
- abstract override def initialize(initialSize: Int, columnName: String, useCompression: Boolean) {
+ abstract override def initialize(
+ initialSize: Int,
+ columnName: String,
+ useCompression: Boolean): Unit = {
+
compressionEncoders =
if (useCompression) {
- schemes.filter(_.supports(columnType)).map(_.encoder[T])
+ schemes.filter(_.supports(columnType)).map(_.encoder[T](columnType))
} else {
- Seq(PassThrough.encoder)
+ Seq(PassThrough.encoder(columnType))
}
super.initialize(initialSize, columnName, useCompression)
}
@@ -62,17 +66,15 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
encoder.compressionRatio < 0.8
}
- private def gatherCompressibilityStats(row: Row, ordinal: Int) {
- val field = columnType.getField(row, ordinal)
-
+ private def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
var i = 0
while (i < compressionEncoders.length) {
- compressionEncoders(i).gatherCompressibilityStats(field, columnType)
+ compressionEncoders(i).gatherCompressibilityStats(row, ordinal)
i += 1
}
}
- abstract override def appendFrom(row: Row, ordinal: Int) {
+ abstract override def appendFrom(row: Row, ordinal: Int): Unit = {
super.appendFrom(row, ordinal)
if (!row.isNullAt(ordinal)) {
gatherCompressibilityStats(row, ordinal)
@@ -84,7 +86,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
val typeId = nonNullBuffer.getInt()
val encoder: Encoder[T] = {
val candidate = compressionEncoders.minBy(_.compressionRatio)
- if (isWorthCompressing(candidate)) candidate else PassThrough.encoder
+ if (isWorthCompressing(candidate)) candidate else PassThrough.encoder(columnType)
}
// Header = column type ID + null count + null positions
@@ -104,7 +106,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType]
.putInt(nullCount)
.put(nulls)
- logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
- encoder.compress(nonNullBuffer, compressedBuffer, columnType)
+ logDebug(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}")
+ encoder.compress(nonNullBuffer, compressedBuffer)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
index 7797f75177893..acb06cb5376b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.columnar.compression
-import java.nio.{ByteOrder, ByteBuffer}
+import java.nio.{ByteBuffer, ByteOrder}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
private[sql] trait Encoder[T <: NativeType] {
- def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {}
+ def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {}
def compressedSize: Int
@@ -33,17 +35,21 @@ private[sql] trait Encoder[T <: NativeType] {
if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
}
- def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer
+ def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
}
-private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType]
+private[sql] trait Decoder[T <: NativeType] {
+ def next(row: MutableRow, ordinal: Int): Unit
+
+ def hasNext: Boolean
+}
private[sql] trait CompressionScheme {
def typeId: Int
def supports(columnType: ColumnType[_, _]): Boolean
- def encoder[T <: NativeType]: Encoder[T]
+ def encoder[T <: NativeType](columnType: NativeColumnType[T]): Encoder[T]
def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
index 8cf9ec74ca2de..29edcf17242c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
@@ -23,7 +23,8 @@ import scala.collection.mutable
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.runtimeMirror
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar._
import org.apache.spark.util.Utils
@@ -33,18 +34,20 @@ private[sql] case object PassThrough extends CompressionScheme {
override def supports(columnType: ColumnType[_, _]) = true
- override def encoder[T <: NativeType] = new this.Encoder[T]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ new this.Encoder[T](columnType)
+ }
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder(buffer, columnType)
}
- class Encoder[T <: NativeType] extends compression.Encoder[T] {
+ class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
override def uncompressedSize = 0
override def compressedSize = 0
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
// Writes compression type ID and copies raw contents
to.putInt(PassThrough.typeId).put(from).rewind()
to
@@ -54,7 +57,9 @@ private[sql] case object PassThrough extends CompressionScheme {
class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T])
extends compression.Decoder[T] {
- override def next() = columnType.extract(buffer)
+ override def next(row: MutableRow, ordinal: Int): Unit = {
+ columnType.extract(buffer, row, ordinal)
+ }
override def hasNext = buffer.hasRemaining
}
@@ -63,7 +68,9 @@ private[sql] case object PassThrough extends CompressionScheme {
private[sql] case object RunLengthEncoding extends CompressionScheme {
override val typeId = 1
- override def encoder[T <: NativeType] = new this.Encoder[T]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ new this.Encoder[T](columnType)
+ }
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
new this.Decoder(buffer, columnType)
@@ -74,24 +81,25 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
case _ => false
}
- class Encoder[T <: NativeType] extends compression.Encoder[T] {
+ class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
private var _uncompressedSize = 0
private var _compressedSize = 0
// Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
- private val lastValue = new GenericMutableRow(1)
+ private val lastValue = new SpecificMutableRow(Seq(columnType.dataType))
private var lastRun = 0
override def uncompressedSize = _uncompressedSize
override def compressedSize = _compressedSize
- override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {
- val actualSize = columnType.actualSize(value)
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = columnType.getField(row, ordinal)
+ val actualSize = columnType.actualSize(row, ordinal)
_uncompressedSize += actualSize
if (lastValue.isNullAt(0)) {
- columnType.setField(lastValue, 0, value)
+ columnType.copyField(row, ordinal, lastValue, 0)
lastRun = 1
_compressedSize += actualSize + 4
} else {
@@ -99,37 +107,40 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
lastRun += 1
} else {
_compressedSize += actualSize + 4
- columnType.setField(lastValue, 0, value)
+ columnType.copyField(row, ordinal, lastValue, 0)
lastRun = 1
}
}
}
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
to.putInt(RunLengthEncoding.typeId)
if (from.hasRemaining) {
- var currentValue = columnType.extract(from)
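+ // Reusable SpecificMutableRow holders let values be copied with copyField, avoiding the
+ // boxing that plain JvmType variables would require for primitive column types.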
+ val currentValue = new SpecificMutableRow(Seq(columnType.dataType))
var currentRun = 1
+ val value = new SpecificMutableRow(Seq(columnType.dataType))
+
+ columnType.extract(from, currentValue, 0)
while (from.hasRemaining) {
- val value = columnType.extract(from)
+ columnType.extract(from, value, 0)
- if (value == currentValue) {
+ if (value.head == currentValue.head) {
currentRun += 1
} else {
// Writes current run
- columnType.append(currentValue, to)
+ columnType.append(currentValue, 0, to)
to.putInt(currentRun)
// Resets current run
- currentValue = value
+ columnType.copyField(value, 0, currentValue, 0)
currentRun = 1
}
}
// Writes the last run
- columnType.append(currentValue, to)
+ columnType.append(currentValue, 0, to)
to.putInt(currentRun)
}
@@ -145,7 +156,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
private var valueCount = 0
private var currentValue: T#JvmType = _
- override def next() = {
+ override def next(row: MutableRow, ordinal: Int): Unit = {
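+ // Reads the next (value, run-length) pair once the current run has been fully emitted.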
if (valueCount == run) {
currentValue = columnType.extract(buffer)
run = buffer.getInt()
@@ -154,7 +165,7 @@ private[sql] case object RunLengthEncoding extends CompressionScheme {
valueCount += 1
}
- currentValue
+ columnType.setField(row, ordinal, currentValue)
}
override def hasNext = valueCount < run || buffer.hasRemaining
@@ -171,14 +182,16 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
new this.Decoder(buffer, columnType)
}
- override def encoder[T <: NativeType] = new this.Encoder[T]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ new this.Encoder[T](columnType)
+ }
override def supports(columnType: ColumnType[_, _]) = columnType match {
case INT | LONG | STRING => true
case _ => false
}
- class Encoder[T <: NativeType] extends compression.Encoder[T] {
+ class Encoder[T <: NativeType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
// Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
// overflows.
private var _uncompressedSize = 0
@@ -200,9 +213,11 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
// to store dictionary element count.
private var dictionarySize = 4
- override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = columnType.getField(row, ordinal)
+
if (!overflow) {
- val actualSize = columnType.actualSize(value)
+ val actualSize = columnType.actualSize(row, ordinal)
count += 1
_uncompressedSize += actualSize
@@ -221,7 +236,7 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
}
}
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
if (overflow) {
throw new IllegalStateException(
"Dictionary encoding should not be used because of dictionary overflow.")
@@ -264,7 +279,9 @@ private[sql] case object DictionaryEncoding extends CompressionScheme {
}
}
- override def next() = dictionary(buffer.getShort())
+ override def next(row: MutableRow, ordinal: Int): Unit = {
+ columnType.setField(row, ordinal, dictionary(buffer.getShort()))
+ }
override def hasNext = buffer.hasRemaining
}
@@ -279,25 +296,20 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ (new this.Encoder).asInstanceOf[compression.Encoder[T]]
+ }
override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN
class Encoder extends compression.Encoder[BooleanType.type] {
private var _uncompressedSize = 0
- override def gatherCompressibilityStats(
- value: Boolean,
- columnType: NativeColumnType[BooleanType.type]) {
-
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
_uncompressedSize += BOOLEAN.defaultSize
}
- override def compress(
- from: ByteBuffer,
- to: ByteBuffer,
- columnType: NativeColumnType[BooleanType.type]) = {
-
+ override def compress(from: ByteBuffer, to: ByteBuffer) = {
to.putInt(BooleanBitSet.typeId)
// Total element count (1 byte per Boolean value)
.putInt(from.remaining)
@@ -349,7 +361,7 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
private var visited: Int = 0
- override def next(): Boolean = {
+ override def next(row: MutableRow, ordinal: Int): Unit = {
val bit = visited % BITS_PER_LONG
visited += 1
@@ -357,123 +369,167 @@ private[sql] case object BooleanBitSet extends CompressionScheme {
currentWord = buffer.getLong()
}
- ((currentWord >> bit) & 1) != 0
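+ // Each boolean is stored as a single bit inside a 64-bit word; extract it and write it out.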
+ row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0)
}
override def hasNext: Boolean = visited < count
}
}
-private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme {
+private[sql] case object IntDelta extends CompressionScheme {
+ override def typeId: Int = 4
+
override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
- new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]])
- .asInstanceOf[compression.Decoder[T]]
+ new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]]
}
- override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]]
-
- /**
- * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or
- * `(false, 0: Byte)` otherwise.
- */
- protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte)
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ (new Encoder).asInstanceOf[compression.Encoder[T]]
+ }
- /**
- * Simply computes `x + delta`
- */
- protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType
+ override def supports(columnType: ColumnType[_, _]) = columnType == INT
- class Encoder extends compression.Encoder[I] {
- private var _compressedSize: Int = 0
+ class Encoder extends compression.Encoder[IntegerType.type] {
+ protected var _compressedSize: Int = 0
+ protected var _uncompressedSize: Int = 0
- private var _uncompressedSize: Int = 0
+ override def compressedSize = _compressedSize
+ override def uncompressedSize = _uncompressedSize
- private var prev: I#JvmType = _
+ private var prevValue: Int = _
- private var initial = true
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = row.getInt(ordinal)
+ val delta = value - prevValue
- override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) {
- _uncompressedSize += columnType.defaultSize
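+ // Every value costs at least one byte: either the byte-sized delta or the escape marker.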
+ _compressedSize += 1
- if (initial) {
- initial = false
- _compressedSize += 1 + columnType.defaultSize
- } else {
- val (smallEnough, _) = byteSizedDelta(value, prev)
- _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize)
+ // If this is the first integer to be compressed, or the delta is out of byte range, then give
+ // up compressing this integer.
+ if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) {
+ _compressedSize += INT.defaultSize
}
- prev = value
+ _uncompressedSize += INT.defaultSize
+ prevValue = value
}
- override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = {
+ override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
to.putInt(typeId)
if (from.hasRemaining) {
- var prev = columnType.extract(from)
+ var prev = from.getInt()
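+ // The first value is always written in full: Byte.MinValue serves as an escape marker
+ // telling the decoder that a complete 4-byte value follows instead of a 1-byte delta.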
to.put(Byte.MinValue)
- columnType.append(prev, to)
+ to.putInt(prev)
while (from.hasRemaining) {
- val current = columnType.extract(from)
- val (smallEnough, delta) = byteSizedDelta(current, prev)
+ val current = from.getInt()
+ val delta = current - prev
prev = current
- if (smallEnough) {
- to.put(delta)
+ if (Byte.MinValue < delta && delta <= Byte.MaxValue) {
+ to.put(delta.toByte)
} else {
to.put(Byte.MinValue)
- columnType.append(current, to)
+ to.putInt(current)
}
}
}
- to.rewind()
- to
+ to.rewind().asInstanceOf[ByteBuffer]
}
-
- override def uncompressedSize = _uncompressedSize
-
- override def compressedSize = _compressedSize
}
- class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I])
- extends compression.Decoder[I] {
+ class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type])
+ extends compression.Decoder[IntegerType.type] {
+
+ private var prev: Int = _
- private var prev: I#JvmType = _
+ override def hasNext: Boolean = buffer.hasRemaining
- override def next() = {
+ override def next(row: MutableRow, ordinal: Int): Unit = {
val delta = buffer.get()
- prev = if (delta > Byte.MinValue) addDelta(prev, delta) else columnType.extract(buffer)
- prev
+ prev = if (delta > Byte.MinValue) prev + delta else buffer.getInt()
+ row.setInt(ordinal, prev)
}
-
- override def hasNext = buffer.hasRemaining
}
}
-private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] {
- override val typeId = 4
+private[sql] case object LongDelta extends CompressionScheme {
+ override def typeId: Int = 5
- override def supports(columnType: ColumnType[_, _]) = columnType == INT
+ override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = {
+ new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]]
+ }
+
+ override def encoder[T <: NativeType](columnType: NativeColumnType[T]) = {
+ (new Encoder).asInstanceOf[compression.Encoder[T]]
+ }
- override protected def addDelta(x: Int, delta: Byte) = x + delta
+ override def supports(columnType: ColumnType[_, _]) = columnType == LONG
+
+ class Encoder extends compression.Encoder[LongType.type] {
+ protected var _compressedSize: Int = 0
+ protected var _uncompressedSize: Int = 0
+
+ override def compressedSize = _compressedSize
+ override def uncompressedSize = _uncompressedSize
+
+ private var prevValue: Long = _
+
+ override def gatherCompressibilityStats(row: Row, ordinal: Int): Unit = {
+ val value = row.getLong(ordinal)
+ val delta = value - prevValue
+
+ _compressedSize += 1
- override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = {
- val delta = x - y
- if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte)
+ // If this is the first long integer to be compressed, or the delta is out of byte range, then
+ // give up compressing this long integer.
+ if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) {
+ _compressedSize += LONG.defaultSize
+ }
+
+ _uncompressedSize += LONG.defaultSize
+ prevValue = value
+ }
+
+ override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
+ to.putInt(typeId)
+
+ if (from.hasRemaining) {
+ var prev = from.getLong()
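+ // As with IntDelta, Byte.MinValue is the escape marker indicating that a full 8-byte value
+ // follows instead of a 1-byte delta.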
+ to.put(Byte.MinValue)
+ to.putLong(prev)
+
+ while (from.hasRemaining) {
+ val current = from.getLong()
+ val delta = current - prev
+ prev = current
+
+ if (Byte.MinValue < delta && delta <= Byte.MaxValue) {
+ to.put(delta.toByte)
+ } else {
+ to.put(Byte.MinValue)
+ to.putLong(current)
+ }
+ }
+ }
+
+ to.rewind().asInstanceOf[ByteBuffer]
+ }
}
-}
-private[sql] case object LongDelta extends IntegralDelta[LongType.type] {
- override val typeId = 5
+ class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type])
+ extends compression.Decoder[LongType.type] {
- override def supports(columnType: ColumnType[_, _]) = columnType == LONG
+ private var prev: Long = _
- override protected def addDelta(x: Long, delta: Byte) = x + delta
+ override def hasNext: Boolean = buffer.hasRemaining
- override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = {
- val delta = x - y
- if (math.abs(delta) <= Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte)
+ override def next(row: MutableRow, ordinal: Int): Unit = {
+ val delta = buffer.get()
+ prev = if (delta > Byte.MinValue) prev + delta else buffer.getLong()
+ row.setLong(ordinal, prev)
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7943d6e1b6fb5..45687d960404c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -305,6 +305,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
Seq(execution.ExplainCommand(logicalPlan, plan.output, extended)(context))
case logical.CacheCommand(tableName, cache) =>
Seq(execution.CacheCommand(tableName, cache)(context))
+ case logical.CacheTableAsSelectCommand(tableName, plan) =>
+ Seq(execution.CacheTableAsSelectCommand(tableName, plan))
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 94543fc95b470..1535291f7908b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -166,3 +166,22 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])(
child.output.map(field => Row(field.name, field.dataType.toString, null))
}
}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class CacheTableAsSelectCommand(tableName: String, plan: LogicalPlan)
+ extends LeafNode with Command {
+
+ override protected[sql] lazy val sideEffectResult = {
+ sqlContext.catalog.registerTable(None, tableName, sqlContext.executePlan(plan).analyzed)
+ sqlContext.cacheTable(tableName)
+ // Performs the caching eagerly by materializing the table.
+ sqlContext.table(tableName).count
+ Seq.empty[Row]
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
index f2389f8f0591e..265b67737c475 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -18,8 +18,13 @@
package org.apache.spark.sql.test
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SQLConf, SQLContext}
/** A SQLContext that can be used for local testing. */
object TestSQLContext
- extends SQLContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
+ extends SQLContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf())) {
+
+ /** Fewer partitions to speed up testing. */
+ override private[spark] def numShufflePartitions: Int =
+ getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index befef46d93973..591592841e9fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -119,4 +119,17 @@ class CachedTableSuite extends QueryTest {
}
assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached")
}
+
+ test("CACHE TABLE tableName AS SELECT Star Table") {
+ TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
+ TestSQLContext.sql("SELECT * FROM testCacheTable WHERE key = 1").collect()
+ assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
+ TestSQLContext.uncacheTable("testCacheTable")
+ }
+
+ test("'CACHE TABLE tableName AS SELECT ..'") {
+ TestSQLContext.sql("CACHE TABLE testCacheTable AS SELECT * FROM testData")
+ assert(TestSQLContext.isCached("testCacheTable"), "Table 'testCacheTable' should be cached")
+ TestSQLContext.uncacheTable("testCacheTable")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
index cde91ceb68c98..0cdbb3167ce36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala
@@ -35,7 +35,7 @@ class ColumnStatsSuite extends FunSuite {
def testColumnStats[T <: NativeType, U <: ColumnStats](
columnStatsClass: Class[U],
columnType: NativeColumnType[T],
- initialStatistics: Row) {
+ initialStatistics: Row): Unit = {
val columnStatsName = columnStatsClass.getSimpleName
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 75f653f3280bd..4fb1ecf1d532b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.scalatest.FunSuite
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -46,10 +47,12 @@ class ColumnTypeSuite extends FunSuite with Logging {
def checkActualSize[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
value: JvmType,
- expected: Int) {
+ expected: Int): Unit = {
assertResult(expected, s"Wrong actualSize for $columnType") {
- columnType.actualSize(value)
+ val row = new GenericMutableRow(1)
+ columnType.setField(row, 0, value)
+ columnType.actualSize(row, 0)
}
}
@@ -147,7 +150,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
def testNativeColumnType[T <: NativeType](
columnType: NativeColumnType[T],
putter: (ByteBuffer, T#JvmType) => Unit,
- getter: (ByteBuffer) => T#JvmType) {
+ getter: (ByteBuffer) => T#JvmType): Unit = {
testColumnType[T, T#JvmType](columnType, putter, getter)
}
@@ -155,7 +158,7 @@ class ColumnTypeSuite extends FunSuite with Logging {
def testColumnType[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
putter: (ByteBuffer, JvmType) => Unit,
- getter: (ByteBuffer) => JvmType) {
+ getter: (ByteBuffer) => JvmType): Unit = {
val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
val seq = (0 until 4).map(_ => makeRandomValue(columnType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 0e3c67f5eed29..c1278248ef655 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.columnar
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{SQLConf, QueryTest, TestData}
+import org.apache.spark.sql.{QueryTest, TestData}
class InMemoryColumnarQuerySuite extends QueryTest {
import org.apache.spark.sql.TestData._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
index 3baa6f8ec0c83..6c9a9ab6c3418 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala
@@ -45,7 +45,9 @@ class NullableColumnAccessorSuite extends FunSuite {
testNullableColumnAccessor(_)
}
- def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
+ def testNullableColumnAccessor[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType]): Unit = {
+
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
val nullRow = makeNullRow(1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
index a77262534a352..f54a21eb4fbb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala
@@ -41,7 +41,9 @@ class NullableColumnBuilderSuite extends FunSuite {
testNullableColumnBuilder(_)
}
- def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) {
+ def testNullableColumnBuilder[T <: DataType, JvmType](
+ columnType: ColumnType[T, JvmType]): Unit = {
+
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
test(s"$typeName column builder: empty column") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 5d2fd4959197c..69e0adbd3ee0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -28,7 +28,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val originalColumnBatchSize = columnBatchSize
val originalInMemoryPartitionPruning = inMemoryPartitionPruning
- override protected def beforeAll() {
+ override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData)
@@ -38,7 +38,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true")
}
- override protected def afterAll() {
+ override protected def afterAll(): Unit = {
setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString)
setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString)
}
@@ -76,7 +76,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
filter: String,
expectedQueryResult: Seq[Int],
expectedReadPartitions: Int,
- expectedReadBatches: Int) {
+ expectedReadBatches: Int): Unit = {
test(filter) {
val query = sql(s"SELECT * FROM intData WHERE $filter")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
index e01cc8b4d20f2..d9e488e0ffd16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN}
import org.apache.spark.sql.columnar.ColumnarTestUtils._
@@ -72,10 +73,14 @@ class BooleanBitSetSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = BooleanBitSet.decoder(buffer, BOOLEAN)
+ val mutableRow = new GenericMutableRow(1)
if (values.nonEmpty) {
values.foreach {
assert(decoder.hasNext)
- assertResult(_, "Wrong decoded value")(decoder.next())
+ assertResult(_, "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ mutableRow.getBoolean(0)
+ }
}
}
assert(!decoder.hasNext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
index d2969d906c943..1cdb909146d57 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
@@ -67,7 +68,7 @@ class DictionaryEncodingSuite extends FunSuite {
val buffer = builder.build()
val headerSize = CompressionScheme.columnHeaderSize(buffer)
// 4 extra bytes for dictionary size
- val dictionarySize = 4 + values.map(columnType.actualSize).sum
+ val dictionarySize = 4 + rows.map(columnType.actualSize(_, 0)).sum
// 2 bytes for each `Short`
val compressedSize = 4 + dictionarySize + 2 * inputSeq.length
// 4 extra bytes for compression scheme type ID
@@ -97,11 +98,15 @@ class DictionaryEncodingSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = DictionaryEncoding.decoder(buffer, columnType)
+ val mutableRow = new GenericMutableRow(1)
if (inputSeq.nonEmpty) {
inputSeq.foreach { i =>
assert(decoder.hasNext)
- assertResult(values(i), "Wrong decoded value")(decoder.next())
+ assertResult(values(i), "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ columnType.getField(mutableRow, 0)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
index 322f447c24840..73f31c0233343 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala
@@ -31,7 +31,7 @@ class IntegralDeltaSuite extends FunSuite {
def testIntegralDelta[I <: IntegralType](
columnStats: ColumnStats,
columnType: NativeColumnType[I],
- scheme: IntegralDelta[I]) {
+ scheme: CompressionScheme) {
def skeleton(input: Seq[I#JvmType]) {
// -------------
@@ -96,10 +96,15 @@ class IntegralDeltaSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = scheme.decoder(buffer, columnType)
+ val mutableRow = new GenericMutableRow(1)
+
if (input.nonEmpty) {
input.foreach{
assert(decoder.hasNext)
- assertResult(_, "Wrong decoded value")(decoder.next())
+ assertResult(_, "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ columnType.getField(mutableRow, 0)
+ }
}
}
assert(!decoder.hasNext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
index 218c09ac26362..4ce2552112c92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.columnar.compression
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.columnar.ColumnarTestUtils._
@@ -57,7 +58,7 @@ class RunLengthEncodingSuite extends FunSuite {
// Compression scheme ID + compressed contents
val compressedSize = 4 + inputRuns.map { case (index, _) =>
// 4 extra bytes each run for run length
- columnType.actualSize(values(index)) + 4
+ columnType.actualSize(rows(index), 0) + 4
}.sum
// 4 extra bytes for compression scheme type ID
@@ -80,11 +81,15 @@ class RunLengthEncodingSuite extends FunSuite {
buffer.rewind().position(headerSize + 4)
val decoder = RunLengthEncoding.decoder(buffer, columnType)
+ val mutableRow = new GenericMutableRow(1)
if (inputSeq.nonEmpty) {
inputSeq.foreach { i =>
assert(decoder.hasNext)
- assertResult(values(i), "Wrong decoded value")(decoder.next())
+ assertResult(values(i), "Wrong decoded value") {
+ decoder.next(mutableRow, 0)
+ columnType.getField(mutableRow, 0)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index b0a06cd3ca090..08f7358446b29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -58,8 +58,7 @@ case class AllDataTypes(
doubleField: Double,
shortField: Short,
byteField: Byte,
- booleanField: Boolean,
- binaryField: Array[Byte])
+ booleanField: Boolean)
case class AllDataTypesWithNonPrimitiveType(
stringField: String,
@@ -70,13 +69,14 @@ case class AllDataTypesWithNonPrimitiveType(
shortField: Short,
byteField: Byte,
booleanField: Boolean,
- binaryField: Array[Byte],
array: Seq[Int],
arrayContainsNull: Seq[Option[Int]],
map: Map[Int, Long],
mapValueContainsNull: Map[Int, Option[Long]],
data: Data)
+case class BinaryData(binaryData: Array[Byte])
+
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.
@@ -108,26 +108,26 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
test("Read/Write All Types") {
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
- TestSQLContext.sparkContext.parallelize(range)
- .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
- (0 to x).map(_.toByte).toArray))
- .saveAsParquetFile(tempDir)
- val result = parquetFile(tempDir).collect()
- range.foreach {
- i =>
- assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}")
- assert(result(i).getInt(1) === i)
- assert(result(i).getLong(2) === i.toLong)
- assert(result(i).getFloat(3) === i.toFloat)
- assert(result(i).getDouble(4) === i.toDouble)
- assert(result(i).getShort(5) === i.toShort)
- assert(result(i).getByte(6) === i.toByte)
- assert(result(i).getBoolean(7) === (i % 2 == 0))
- assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
- }
+ val data = sparkContext.parallelize(range)
+ .map(x => AllDataTypes(s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0))
+
+ data.saveAsParquetFile(tempDir)
+
+ checkAnswer(
+ parquetFile(tempDir),
+ data.toSchemaRDD.collect().toSeq)
}
- test("Treat binary as string") {
+ test("read/write binary data") {
+ // Since equality for Array[Byte] is broken, we test this separately.
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ sparkContext.parallelize(BinaryData("test".getBytes("utf8")) :: Nil).saveAsParquetFile(tempDir)
+ assert(parquetFile(tempDir)
+ .map(r => new String(r(0).asInstanceOf[Array[Byte]], "utf8"))
+ .collect().toSeq === Seq("test"))
+ }
+
+ ignore("Treat binary as string") {
val oldIsParquetBinaryAsString = TestSQLContext.isParquetBinaryAsString
// Create the test file.
@@ -142,37 +142,16 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
StructField("c2", BinaryType, false) :: Nil)
val schemaRDD1 = applySchema(rowRDD, schema)
schemaRDD1.saveAsParquetFile(path)
- val resultWithBinary = parquetFile(path).collect
- range.foreach {
- i =>
- assert(resultWithBinary(i).getInt(0) === i)
- assert(resultWithBinary(i)(1) === s"val_$i".getBytes)
- }
-
- TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true")
- // This ParquetRelation always use Parquet types to derive output.
- val parquetRelation = new ParquetRelation(
- path.toString,
- Some(TestSQLContext.sparkContext.hadoopConfiguration),
- TestSQLContext) {
- override val output =
- ParquetTypesConverter.convertToAttributes(
- ParquetTypesConverter.readMetaData(new Path(path), conf).getFileMetaData.getSchema,
- TestSQLContext.isParquetBinaryAsString)
- }
- val schemaRDD = new SchemaRDD(TestSQLContext, parquetRelation)
- val resultWithString = schemaRDD.collect
- range.foreach {
- i =>
- assert(resultWithString(i).getInt(0) === i)
- assert(resultWithString(i)(1) === s"val_$i")
- }
+ checkAnswer(
+ parquetFile(path).select('c1, 'c2.cast(StringType)),
+ schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq)
- schemaRDD.registerTempTable("tmp")
+ setConf(SQLConf.PARQUET_BINARY_AS_STRING, "true")
+ parquetFile(path).printSchema()
checkAnswer(
- sql("SELECT c1, c2 FROM tmp WHERE c2 = 'val_5' OR c2 = 'val_7'"),
- (5, "val_5") ::
- (7, "val_7") :: Nil)
+ parquetFile(path),
+ schemaRDD1.select('c1, 'c2.cast(StringType)).collect().toSeq)
+
// Set it back.
TestSQLContext.setConf(SQLConf.PARQUET_BINARY_AS_STRING, oldIsParquetBinaryAsString.toString)
@@ -275,34 +254,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
test("Read/Write All Types with non-primitive type") {
val tempDir = getTempFilePath("parquetTest").getCanonicalPath
val range = (0 to 255)
- TestSQLContext.sparkContext.parallelize(range)
+ val data = sparkContext.parallelize(range)
.map(x => AllDataTypesWithNonPrimitiveType(
s"$x", x, x.toLong, x.toFloat, x.toDouble, x.toShort, x.toByte, x % 2 == 0,
- (0 to x).map(_.toByte).toArray,
(0 until x),
(0 until x).map(Option(_).filter(_ % 3 == 0)),
(0 until x).map(i => i -> i.toLong).toMap,
(0 until x).map(i => i -> Option(i.toLong)).toMap + (x -> None),
Data((0 until x), Nested(x, s"$x"))))
- .saveAsParquetFile(tempDir)
- val result = parquetFile(tempDir).collect()
- range.foreach {
- i =>
- assert(result(i).getString(0) == s"$i", s"row $i String field did not match, got ${result(i).getString(0)}")
- assert(result(i).getInt(1) === i)
- assert(result(i).getLong(2) === i.toLong)
- assert(result(i).getFloat(3) === i.toFloat)
- assert(result(i).getDouble(4) === i.toDouble)
- assert(result(i).getShort(5) === i.toShort)
- assert(result(i).getByte(6) === i.toByte)
- assert(result(i).getBoolean(7) === (i % 2 == 0))
- assert(result(i)(8) === (0 to i).map(_.toByte).toArray)
- assert(result(i)(9) === (0 until i))
- assert(result(i)(10) === (0 until i).map(i => if (i % 3 == 0) i else null))
- assert(result(i)(11) === (0 until i).map(i => i -> i.toLong).toMap)
- assert(result(i)(12) === (0 until i).map(i => i -> i.toLong).toMap + (i -> null))
- assert(result(i)(13) === new GenericRow(Array[Any]((0 until i), new GenericRow(Array[Any](i, s"$i")))))
- }
+ data.saveAsParquetFile(tempDir)
+
+ checkAnswer(
+ parquetFile(tempDir),
+ data.toSchemaRDD.collect().toSeq)
}
test("self-join parquet files") {
@@ -399,23 +363,6 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
}
}
- test("Saving case class RDD table to file and reading it back in") {
- val file = getTempFilePath("parquet")
- val path = file.toString
- val rdd = TestSQLContext.sparkContext.parallelize((1 to 100))
- .map(i => TestRDDEntry(i, s"val_$i"))
- rdd.saveAsParquetFile(path)
- val readFile = parquetFile(path)
- readFile.registerTempTable("tmpx")
- val rdd_copy = sql("SELECT * FROM tmpx").collect()
- val rdd_orig = rdd.collect()
- for(i <- 0 to 99) {
- assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i")
- assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value error in line $i")
- }
- Utils.deleteRecursively(file)
- }
-
test("Read a parquet file instead of a directory") {
val file = getTempFilePath("parquet")
val path = file.toString
@@ -448,32 +395,19 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect()
val rdd_copy1 = sql("SELECT * FROM dest").collect()
assert(rdd_copy1.size === 100)
- assert(rdd_copy1(0).apply(0) === 1)
- assert(rdd_copy1(0).apply(1) === "val_1")
- // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is
- // executed twice otherwise?!
+
sql("INSERT INTO dest SELECT * FROM source")
- val rdd_copy2 = sql("SELECT * FROM dest").collect()
+ val rdd_copy2 = sql("SELECT * FROM dest").collect().sortBy(_.getInt(0))
assert(rdd_copy2.size === 200)
- assert(rdd_copy2(0).apply(0) === 1)
- assert(rdd_copy2(0).apply(1) === "val_1")
- assert(rdd_copy2(99).apply(0) === 100)
- assert(rdd_copy2(99).apply(1) === "val_100")
- assert(rdd_copy2(100).apply(0) === 1)
- assert(rdd_copy2(100).apply(1) === "val_1")
Utils.deleteRecursively(dirname)
}
test("Insert (appending) to same table via Scala API") {
- // TODO: why does collecting break things? It seems InsertIntoParquet::execute() is
- // executed twice otherwise?!
sql("INSERT INTO testsource SELECT * FROM testsource")
val double_rdd = sql("SELECT * FROM testsource").collect()
assert(double_rdd != null)
assert(double_rdd.size === 30)
- for(i <- (0 to 14)) {
- assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match")
- }
+
// let's restore the original test data
Utils.deleteRecursively(ParquetTestData.testDir)
ParquetTestData.writeFile()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 21ecf17028dbc..86d2aad71607a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -229,7 +229,13 @@ private[hive] object HiveQl {
SetCommand(Some(key), Some(value))
}
} else if (sql.trim.toLowerCase.startsWith("cache table")) {
- CacheCommand(sql.trim.drop(12).trim, true)
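+ // "CACHE TABLE <name>" caches an existing table, while "CACHE TABLE <name> AS SELECT ..."
+ // first registers the query result as <name> and then caches it eagerly.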
+ sql.trim.drop(12).trim.split(" ").toSeq match {
+ case Seq(tableName) =>
+ CacheCommand(tableName, true)
+ case Seq(tableName, as, select @ _*) =>
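+ // Strip the leading "CACHE TABLE <name> AS " prefix (the extra 2 accounts for the
+ // spaces around "AS") and parse the remaining SELECT as the plan to cache.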
+ CacheTableAsSelectCommand(tableName,
+ createPlan(sql.trim.drop(12 + tableName.length() + as.length() + 2)))
+ }
} else if (sql.trim.toLowerCase.startsWith("uncache table")) {
CacheCommand(sql.trim.drop(14).trim, false)
} else if (sql.trim.toLowerCase.startsWith("add jar")) {
@@ -243,15 +249,7 @@ private[hive] object HiveQl {
} else if (sql.trim.startsWith("!")) {
ShellCommand(sql.drop(1))
} else {
- val tree = getAst(sql)
- if (nativeCommands contains tree.getText) {
- NativeCommand(sql)
- } else {
- nodeToPlan(tree) match {
- case NativePlaceholder => NativeCommand(sql)
- case other => other
- }
- }
+ createPlan(sql)
}
} catch {
case e: Exception => throw new ParseException(sql, e)
@@ -262,6 +260,19 @@ private[hive] object HiveQl {
""".stripMargin)
}
}
+
+ /** Creates LogicalPlan for a given HiveQL string. */
+ def createPlan(sql: String) = {
+ val tree = getAst(sql)
+ if (nativeCommands contains tree.getText) {
+ NativeCommand(sql)
+ } else {
+ nodeToPlan(tree) match {
+ case NativePlaceholder => NativeCommand(sql)
+ case other => other
+ }
+ }
+ }
def parseDdl(ddl: String): Seq[Attribute] = {
val tree =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index 329f80cad471e..84fafcde63d05 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -25,16 +25,14 @@ import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table =>
import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
import org.apache.hadoop.hive.serde2.Deserializer
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
-
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf}
import org.apache.spark.SerializableWritable
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Row, GenericMutableRow, Literal, Cast}
-import org.apache.spark.sql.catalyst.types.DataType
+import org.apache.spark.sql.catalyst.expressions._
/**
* A trait for subclasses that handle table scans.
@@ -108,12 +106,12 @@ class HadoopTableReader(
val hadoopRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
val attrsWithIndex = attributes.zipWithIndex
- val mutableRow = new GenericMutableRow(attrsWithIndex.length)
+ val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+
val deserializedHadoopRDD = hadoopRDD.mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = deserializerClass.newInstance()
deserializer.initialize(hconf, tableDesc.getProperties)
-
HadoopTableReader.fillObject(iter, deserializer, attrsWithIndex, mutableRow)
}
@@ -164,33 +162,32 @@ class HadoopTableReader(
val tableDesc = relation.tableDesc
val broadcastedHiveConf = _broadcastedHiveConf
val localDeserializer = partDeserializer
- val mutableRow = new GenericMutableRow(attributes.length)
-
- // split the attributes (output schema) into 2 categories:
- // (partition keys, ordinal), (normal attributes, ordinal), the ordinal mean the
- // index of the attribute in the output Row.
- val (partitionKeys, attrs) = attributes.zipWithIndex.partition(attr => {
- relation.partitionKeys.indexOf(attr._1) >= 0
- })
-
- def fillPartitionKeys(parts: Array[String], row: GenericMutableRow) = {
- partitionKeys.foreach { case (attr, ordinal) =>
- // get partition key ordinal for a given attribute
- val partOridinal = relation.partitionKeys.indexOf(attr)
- row(ordinal) = Cast(Literal(parts(partOridinal)), attr.dataType).eval(null)
+ val mutableRow = new SpecificMutableRow(attributes.map(_.dataType))
+
+ // Splits all attributes into two groups, partition key attributes and those that are not.
+ // Attached indices indicate the position of each attribute in the output schema.
+ val (partitionKeyAttrs, nonPartitionKeyAttrs) =
+ attributes.zipWithIndex.partition { case (attr, _) =>
+ relation.partitionKeys.contains(attr)
+ }
+
+ def fillPartitionKeys(rawPartValues: Array[String], row: MutableRow) = {
+ partitionKeyAttrs.foreach { case (attr, ordinal) =>
+ val partOrdinal = relation.partitionKeys.indexOf(attr)
+ row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
}
}
- // fill the partition key for the given MutableRow Object
+
+ // Fill all partition keys into the given MutableRow object
fillPartitionKeys(partValues, mutableRow)
- val hivePartitionRDD = createHadoopRdd(tableDesc, inputPathStr, ifc)
- hivePartitionRDD.mapPartitions { iter =>
+ createHadoopRdd(tableDesc, inputPathStr, ifc).mapPartitions { iter =>
val hconf = broadcastedHiveConf.value.value
val deserializer = localDeserializer.newInstance()
deserializer.initialize(hconf, partProps)
- // fill the non partition key attributes
- HadoopTableReader.fillObject(iter, deserializer, attrs, mutableRow)
+ // Fill all non-partition key attributes into the given MutableRow object
+ HadoopTableReader.fillObject(iter, deserializer, nonPartitionKeyAttrs, mutableRow)
}
}.toSeq
@@ -257,38 +254,64 @@ private[hive] object HadoopTableReader extends HiveInspectors {
}
/**
- * Transform the raw data(Writable object) into the Row object for an iterable input
- * @param iter Iterable input which represented as Writable object
- * @param deserializer Deserializer associated with the input writable object
- * @param attrs Represents the row attribute names and its zero-based position in the MutableRow
- * @param row reusable MutableRow object
- *
- * @return Iterable Row object that transformed from the given iterable input.
+ * Transform all given raw `Writable`s into `Row`s.
+ *
+ * @param iterator Iterator of all `Writable`s to be transformed
+ * @param deserializer The `Deserializer` associated with the input `Writable`
+ * @param nonPartitionKeyAttrs Attributes that should be filled together with their corresponding
+ * positions in the output schema
+ * @param mutableRow A reusable `MutableRow` that should be filled
+ * @return An `Iterator[Row]` transformed from `iterator`
*/
def fillObject(
- iter: Iterator[Writable],
+ iterator: Iterator[Writable],
deserializer: Deserializer,
- attrs: Seq[(Attribute, Int)],
- row: GenericMutableRow): Iterator[Row] = {
+ nonPartitionKeyAttrs: Seq[(Attribute, Int)],
+ mutableRow: MutableRow): Iterator[Row] = {
+
val soi = deserializer.getObjectInspector().asInstanceOf[StructObjectInspector]
- // get the field references according to the attributes(output of the reader) required
- val fieldRefs = attrs.map { case (attr, idx) => (soi.getStructFieldRef(attr.name), idx) }
+ val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) =>
+ soi.getStructFieldRef(attr.name) -> ordinal
+ }.unzip
+
+ // Builds specific unwrappers ahead of time according to object inspector types to avoid pattern
+ // matching and branching costs per row.
+ val unwrappers: Seq[(Any, MutableRow, Int) => Unit] = fieldRefs.map {
+ _.getFieldObjectInspector match {
+ case oi: BooleanObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setBoolean(ordinal, oi.get(value))
+ case oi: ByteObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setByte(ordinal, oi.get(value))
+ case oi: ShortObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setShort(ordinal, oi.get(value))
+ case oi: IntObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setInt(ordinal, oi.get(value))
+ case oi: LongObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setLong(ordinal, oi.get(value))
+ case oi: FloatObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setFloat(ordinal, oi.get(value))
+ case oi: DoubleObjectInspector =>
+ (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value))
+ case oi =>
+ (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = unwrapData(value, oi)
+ }
+ }
// Map each tuple to a row object
- iter.map { value =>
+ iterator.map { value =>
val raw = deserializer.deserialize(value)
- var idx = 0;
- while (idx < fieldRefs.length) {
- val fieldRef = fieldRefs(idx)._1
- val fieldIdx = fieldRefs(idx)._2
- val fieldValue = soi.getStructFieldData(raw, fieldRef)
-
- row(fieldIdx) = unwrapData(fieldValue, fieldRef.getFieldObjectInspector())
-
- idx += 1
+ var i = 0
+ while (i < fieldRefs.length) {
+ val fieldValue = soi.getStructFieldData(raw, fieldRefs(i))
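+ // The mutable row is reused across input records, so null fields must be reset explicitly.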
+ if (fieldValue == null) {
+ mutableRow.setNullAt(fieldOrdinals(i))
+ } else {
+ unwrappers(i)(fieldValue, mutableRow, fieldOrdinals(i))
+ }
+ i += 1
}
- row: Row
+ mutableRow: Row
}
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index a3bfd3a8f1fd2..70fb15259e7d7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -35,12 +35,13 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.hive._
+import org.apache.spark.sql.SQLConf
/* Implicit conversions */
import scala.collection.JavaConversions._
object TestHive
- extends TestHiveContext(new SparkContext("local", "TestSQLContext", new SparkConf()))
+ extends TestHiveContext(new SparkContext("local[2]", "TestSQLContext", new SparkConf()))
/**
* A locally running test instance of Spark's Hive execution engine.
@@ -90,6 +91,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
override def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution { val logical = plan }
+ /** Fewer partitions to speed up testing. */
+ override private[spark] def numShufflePartitions: Int =
+ getConf(SQLConf.SHUFFLE_PARTITIONS, "5").toInt
+
/**
* Returns the value of specified environmental variable as a [[java.io.File]] after checking
* to ensure it exists
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 671c3b162f875..79cc7a3fcc7d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -250,9 +250,9 @@ abstract class HiveComparisonTest
}
try {
- // MINOR HACK: You must run a query before calling reset the first time.
- TestHive.sql("SHOW TABLES")
- if (reset) { TestHive.reset() }
+ if (reset) {
+ TestHive.reset()
+ }
val hiveCacheFiles = queryList.zipWithIndex.map {
case (queryString, i) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 6bf8d18a5c32c..8c8a8b124ac69 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -295,8 +295,16 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15")
test("implement identity function using case statement") {
- val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet
- val expected = sql("SELECT key FROM src").collect().toSet
+ val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src")
+ .map { case Row(i: Int) => i }
+ .collect()
+ .toSet
+
+ val expected = sql("SELECT key FROM src")
+ .map { case Row(i: Int) => i }
+ .collect()
+ .toSet
+
assert(actual === expected)
}
@@ -559,9 +567,9 @@ class HiveQuerySuite extends HiveComparisonTest {
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
val KV = "([^=]+)=([^=]*)".r
- def collectResults(rdd: SchemaRDD): Set[(String, String)] =
- rdd.collect().map {
- case Row(key: String, value: String) => key -> value
+ def collectResults(rdd: SchemaRDD): Set[(String, String)] =
+ rdd.collect().map {
+ case Row(key: String, value: String) => key -> value
case Row(KV(key, value)) => key -> value
}.toSet
clear()