@@ -145,36 +145,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) }
}

val signal = new Object
val partitions = collection.mutable.Map.empty[Int, Array[Batch]]

val processPartition = (iter: Iterator[Batch]) => iter.toArray

// This callback is executed by the DAGScheduler thread.
// After fetching a partition, it inserts the partition into the Map, and then
// wakes up the main thread.
val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
  signal.synchronized {
    partitions(partitionId) = partition
    signal.notify()
  }
  ()
}

spark.sparkContext.runJob(batches, processPartition, resultHandler)

// The main thread will wait until the 0-th partition is available,
// then send it to the client and wait for the next partition.
var currentPartitionId = 0
while (currentPartitionId < numPartitions) {
  val partition = signal.synchronized {
    while (!partitions.contains(currentPartitionId)) {
      signal.wait()
    }
    partitions.remove(currentPartitionId).get
  }

  partition.foreach { case (bytes, count) =>
def writeBatches(arrowBatches: Array[Batch]): Unit = {

Contributor:

The reason I suggested using locks and having the main thread write the results is exactly what this comment is trying to convey. You don't want these operations to happen inside the DAGScheduler thread. If you keep it blocked on something that is not scheduling-related, you stall all other scheduling. This is particularly bad in an environment where multiple users might be running code at the same time.

Contributor:

We have seen in higher concurrency scenarios that this does become a problem. Throughput will plateau because the DAGScheduler is doing the wrong things. I would like to avoid that.

Contributor:

Good point! We should write it down as code comments. @zhengruifeng can you help with it?

Contributor:

OK, let me add a comment for this.
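
Not part of the PR itself: below is a minimal, standalone Scala sketch of the pattern the reviewer describes, where the scheduler-side callback only hands partitions off (here via a LinkedBlockingQueue, an assumption for illustration) and a separate caller thread performs the slow, ordered writes. The Batch alias mirrors the (bytes, rowCount) pairs in the diff; everything else (object name, the queue-based handoff, println standing in for writeBatches) is hypothetical.

import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable

object OrderedHandoffSketch {
  type Batch = (Array[Byte], Long) // mirrors the (bytes, rowCount) pairs in the diff

  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val queue = new LinkedBlockingQueue[(Int, Array[Batch])]()

    // Stand-in for the DAGScheduler callback: enqueue and return immediately,
    // so the scheduler thread is never blocked on client I/O.
    val resultHandler: (Int, Array[Batch]) => Unit =
      (partitionId, partition) => queue.put((partitionId, partition))

    // Simulate partitions finishing out of order on other threads.
    Seq(2, 0, 3, 1).foreach { pid =>
      new Thread(() => resultHandler(pid, Array((Array(pid.toByte), 1L)))).start()
    }

    // Caller (gRPC handler) thread: drain in partition order, buffering
    // out-of-order arrivals; this is where writeBatches-style work would go.
    val pending = mutable.Map.empty[Int, Array[Batch]]
    var next = 0
    while (next < numPartitions) {
      val (pid, partition) = queue.take()
      pending(pid) = partition
      while (pending.contains(next)) {
        val p = pending.remove(next).get
        println(s"writing partition $next (${p.length} batch(es))")
        next += 1
      }
    }
  }
}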

  for (arrowBatch <- arrowBatches) {
    val (bytes, count) = arrowBatch
    val response = proto.Response.newBuilder().setClientId(clientId)
    val batch = proto.Response.ArrowBatch
      .newBuilder()
@@ -185,9 +159,30 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
    responseObserver.onNext(response.build())
    numSent += 1
  }
}

// Store collection results for worst case of 1 to N-1 partitions
val results = new Array[Array[Batch]](numPartitions - 1)
var lastIndex = -1 // index of last partition written

currentPartitionId += 1
// Handler to eagerly write partitions in order
val resultHandler = (partitionId: Int, partition: Array[Batch]) => {

Contributor:

Does it need to be synchronized?

Member Author (@HyukjinKwon, Nov 11, 2022):

Nope, it doesn't (because it's guided by the index). This approach is actually from the initial ordered implementation of collect with Arrow (which was in production for a very long time): 82c18c2#diff-459628811d7786c705fbb2b7a381ecd2b88f183f44ab607d43b3d33ea48d390fR3282-R3318.

  // If result is from next partition in order
  if (partitionId - 1 == lastIndex) {
    writeBatches(partition)
    lastIndex += 1
    // Write stored partitions that come next in order
    while (lastIndex < results.length && results(lastIndex) != null) {
      writeBatches(results(lastIndex))
      results(lastIndex) = null
      lastIndex += 1
    }
  } else {
    // Store partitions received out of order
    results(partitionId - 1) = partition
  }
}
spark.sparkContext.runJob(batches, (iter: Iterator[Batch]) => iter.toArray, resultHandler)

Contributor:

#38468 (comment)

Maybe we can create a thread pool? (shared across collect invocations)

Member Author:

Hm, I just noticed the review comment. I believe this matches our current implementation in PySpark. If we should improve it, let's improve both paths together. I would prefer to match them and deduplicate the logic first.

Contributor:

+1 for matching the implementations

Contributor:

Why use a thread pool if you have a thread sitting around?
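
Not part of the PR: a minimal standalone sketch of the shared-thread-pool idea floated earlier in this thread, where the write work for every collect invocation is submitted to one process-wide pool instead of running on a scheduler thread or spawning a thread per call. The pool size, object name, and the collectAndWrite helper are assumptions for illustration only; the PR itself did not adopt this.

import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object SharedWritePoolSketch {
  // One pool shared across all collect invocations in the process.
  private val writePool: ExecutionContext =
    ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(8))

  // Submit the (potentially slow) write loop for one collect to the shared pool.
  def collectAndWrite(partitionIds: Range)(write: Int => Unit): Future[Unit] =
    Future(partitionIds.foreach(write))(writePool)

  def main(args: Array[String]): Unit = {
    val done = collectAndWrite(0 until 4)(pid => println(s"wrote partition $pid"))
    Await.result(done, Duration.Inf)
  }
}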

}

// Make sure at least 1 batch will be sent.