13 changes: 13 additions & 0 deletions dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -21,6 +21,11 @@ arrow-memory-core/18.0.0//arrow-memory-core-18.0.0.jar
arrow-memory-netty-buffer-patch/18.0.0//arrow-memory-netty-buffer-patch-18.0.0.jar
arrow-memory-netty/18.0.0//arrow-memory-netty-18.0.0.jar
arrow-vector/18.0.0//arrow-vector-18.0.0.jar
asm-analysis/9.2//asm-analysis-9.2.jar
asm-commons/9.2//asm-commons-9.2.jar
asm-tree/9.2//asm-tree-9.2.jar
asm-util/9.2//asm-util-9.2.jar
asm/9.2//asm-9.2.jar
@LuciferYang (Contributor) commented on Nov 19, 2024:
The Spark master branch is currently using asm 9.7.1, and I think we should unify on 9.7.1 because this affects which Java versions are supported: asm 9.2 does not support Java 19+.
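A hedged illustration of what "does not support Java 19+" means in practice (not part of this PR; the file name and wrapper object are assumptions): asm's ClassReader rejects class files whose major version is newer than the asm release knows about, so bytecode compiled by JDK 19+ cannot even be parsed with asm 9.2.

import java.nio.file.{Files, Path}
import org.objectweb.asm.ClassReader

object AsmVersionCheck {
  def main(args: Array[String]): Unit = {
    // Assumed input: a class file produced by JDK 19+ (class-file major version 63+).
    val bytes = Files.readAllBytes(Path.of("Example.class"))
    // With asm 9.2 on the classpath this constructor throws
    // IllegalArgumentException("Unsupported class file major version ..."),
    // while asm 9.7.1 parses the same bytes without error.
    val reader = new ClassReader(bytes)
    println(s"Parsed class: ${reader.getClassName}")
  }
}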

audience-annotations/0.12.0//audience-annotations-0.12.0.jar
avro-ipc/1.12.0//avro-ipc-1.12.0.jar
avro-mapred/1.12.0//avro-mapred-1.12.0.jar
@@ -142,10 +147,18 @@ jersey-server/3.0.16//jersey-server-3.0.16.jar
jettison/1.5.4//jettison-1.5.4.jar
jetty-util-ajax/11.0.24//jetty-util-ajax-11.0.24.jar
jetty-util/11.0.24//jetty-util-11.0.24.jar
jffi/1.3.9//jffi-1.3.9.jar
jffi/1.3.9/native/jffi-1.3.9-native.jar
jjwt-api/0.12.6//jjwt-api-0.12.6.jar
jline/2.14.6//jline-2.14.6.jar
jline/3.26.3//jline-3.26.3.jar
jna/5.14.0//jna-5.14.0.jar
jnr-a64asm/1.0.0//jnr-a64asm-1.0.0.jar
jnr-constants/0.10.3//jnr-constants-0.10.3.jar
jnr-enxio/0.32.13//jnr-enxio-0.32.13.jar
jnr-ffi/2.2.11//jnr-ffi-2.2.11.jar
jnr-unixsocket/0.38.18//jnr-unixsocket-0.38.18.jar
jnr-x86asm/1.0.2//jnr-x86asm-1.0.2.jar
joda-time/2.13.0//joda-time-2.13.0.jar
jodd-core/3.5.2//jodd-core-3.5.2.jar
jpam/1.1//jpam-1.1.jar
7 changes: 4 additions & 3 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -48,10 +48,11 @@ class StatefulProcessorHandleState(Enum):


class StatefulProcessorApiClient:
def __init__(self, state_server_port: int, key_schema: StructType) -> None:
def __init__(self, state_server_id: int, key_schema: StructType) -> None:
self.key_schema = key_schema
self._client_socket = socket.socket()
self._client_socket.connect(("localhost", state_server_port))
self._client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_address = f"./uds_{state_server_id}.sock"
self._client_socket.connect(server_address)
self.sockfile = self._client_socket.makefile(
"rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
)
8 changes: 4 additions & 4 deletions python/pyspark/worker.py
@@ -1449,7 +1449,7 @@ def mapper(_, it):
def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}

state_server_port = None
state_server_id = None
Contributor commented:
Why do we need this change?

key_schema = None
if eval_type in (
PythonEvalType.SQL_ARROW_BATCHED_UDF,
@@ -1481,7 +1481,7 @@ def read_udfs(pickleSer, infile, eval_type):
eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
):
state_server_port = read_int(infile)
state_server_id = read_int(infile)
key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))

# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
@@ -1695,7 +1695,7 @@ def mapper(a):
)
parsed_offsets = extract_key_value_indexes(arg_offsets)
ser.key_offsets = parsed_offsets[0][0]
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
stateful_processor_api_client = StatefulProcessorApiClient(state_server_id, key_schema)

# Create function like this:
# mapper a: f([a[0]], [a[0], a[1]])
@@ -1728,7 +1728,7 @@ def values_gen():
parsed_offsets = extract_key_value_indexes(arg_offsets)
ser.key_offsets = parsed_offsets[0][0]
ser.init_key_offsets = parsed_offsets[1][0]
stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
stateful_processor_api_client = StatefulProcessorApiClient(state_server_id, key_schema)

def mapper(a):
key = a[0]
5 changes: 5 additions & 0 deletions sql/core/pom.xml
@@ -252,6 +252,11 @@
<artifactId>htmlunit3-driver</artifactId>
<scope>test</scope>
</dependency>
<dependency>
@LuciferYang (Contributor) commented on Nov 19, 2024:
  1. We should add this to dependency management first (a sketch follows the dependency block below).
  2. For newly added dependencies, the corresponding LICENSE should be maintained. Am I right? @dongjoon-hyun @yaooqinn

Member replied:
For non-ASF ones, we will include their licenses in our licenses-binary directory and LICENSE-binary file, and also update the NOTICE-binary document for any applicable notices.

<groupId>com.github.jnr</groupId>
<artifactId>jnr-unixsocket</artifactId>
<version>0.38.18</version>
</dependency>
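Regarding point 1 in the comment above, a minimal sketch of what centralized version management could look like — assuming the artifact is declared once in the parent pom's <dependencyManagement>, which is an assumption and not part of this PR:

<!-- parent pom, inside <dependencyManagement><dependencies> (assumed placement) -->
<dependency>
  <groupId>com.github.jnr</groupId>
  <artifactId>jnr-unixsocket</artifactId>
  <version>0.38.18</version>
</dependency>

sql/core/pom.xml would then reference the artifact without repeating the version:

<dependency>
  <groupId>com.github.jnr</groupId>
  <artifactId>jnr-unixsocket</artifactId>
</dependency>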
<!-- Explicit declaration of bouncy-castle dependencies are
needed for maven test builds on later hadoop releases.-->
<dependency>
@@ -18,10 +18,12 @@
package org.apache.spark.sql.execution.python

import java.io.DataOutputStream
import java.net.ServerSocket
import java.nio.file.{Files, Path}

import scala.concurrent.ExecutionContext

import jnr.unixsocket.UnixServerSocketChannel
import jnr.unixsocket.UnixSocketAddress
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

@@ -181,7 +183,9 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
protected val sqlConf = SQLConf.get
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch

private var stateServerSocketPort: Int = 0
private val serverId = TransformWithStateInPandasStateServer.allocateServerId()

private val socketPath = s"./uds_$serverId.sock"

override protected val workerConf: Map[String, String] = initialWorkerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
@@ -195,53 +199,57 @@

override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
// Also write the port number for state server
stream.writeInt(stateServerSocketPort)
// Also write the service id for state server
stream.writeInt(serverId)
PythonRDD.writeUTF(groupingKeySchema.json, stream)
}

override def compute(
inputIterator: Iterator[I],
partitionIndex: Int,
context: TaskContext): Iterator[ColumnarBatch] = {
var stateServerSocket: ServerSocket = null
var serverChannel: UnixServerSocketChannel = null
var failed = false
try {
stateServerSocket = new ServerSocket( /* port = */ 0,
/* backlog = */ 1)
stateServerSocketPort = stateServerSocket.getLocalPort
val socketFile = Path.of(socketPath)
Files.deleteIfExists(socketFile)
val serverAddress = new UnixSocketAddress(socketPath)
serverChannel = UnixServerSocketChannel.open()
serverChannel.socket().bind(serverAddress)
} catch {
case e: Throwable =>
failed = true
throw e
} finally {
if (failed) {
closeServerSocketChannelSilently(stateServerSocket)
closeServerSocketChannelSilently(serverChannel)
}
}

val executor = ThreadUtils.newDaemonSingleThreadExecutor("stateConnectionListenerThread")
val executionContext = ExecutionContext.fromExecutor(executor)

executionContext.execute(
new TransformWithStateInPandasStateServer(stateServerSocket, processorHandle,
new TransformWithStateInPandasStateServer(serverChannel, processorHandle,
groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch,
batchTimestampMs, eventTimeWatermarkForEviction))

context.addTaskCompletionListener[Unit] { _ =>
logInfo(log"completion listener called")
executor.shutdownNow()
closeServerSocketChannelSilently(stateServerSocket)
closeServerSocketChannelSilently(serverChannel)
val socketFile = Path.of(socketPath)
Files.deleteIfExists(socketFile)
}

super.compute(inputIterator, partitionIndex, context)
}

private def closeServerSocketChannelSilently(stateServerSocket: ServerSocket): Unit = {
private def closeServerSocketChannelSilently(serverChannel: UnixServerSocketChannel): Unit = {
try {
logInfo(log"closing the state server socket")
stateServerSocket.close()
serverChannel.close()
} catch {
case e: Exception =>
logError(log"failed to close state server socket", e)
@@ -17,13 +17,14 @@

package org.apache.spark.sql.execution.python

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
import java.net.ServerSocket
import java.io.{DataInputStream, DataOutputStream, EOFException}
import java.nio.channels.Channels
import java.time.Duration

import scala.collection.mutable

import com.google.protobuf.ByteString
import jnr.unixsocket.UnixServerSocketChannel
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter

@@ -51,7 +52,7 @@ import org.apache.spark.util.Utils
* - Requests for managing state variables (e.g. valueState).
*/
class TransformWithStateInPandasStateServer(
stateServerSocket: ServerSocket,
serverChannel: UnixServerSocketChannel,
statefulProcessorHandle: StatefulProcessorHandleImpl,
groupingKeySchema: StructType,
timeZoneId: String,
@@ -134,14 +135,15 @@
} else new mutable.HashMap[String, Iterator[Long]]()

def run(): Unit = {
val listeningSocket = stateServerSocket.accept()
val channel = serverChannel.accept()
inputStream = new DataInputStream(
new BufferedInputStream(listeningSocket.getInputStream))
Channels.newInputStream(channel)
)
outputStream = new DataOutputStream(
new BufferedOutputStream(listeningSocket.getOutputStream)
Channels.newOutputStream(channel)
)

while (listeningSocket.isConnected &&
while (channel.isConnected &&
statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) {
try {
val version = inputStream.readInt()
@@ -758,3 +760,12 @@ case class MapStateInfo(
keySerializer: ExpressionEncoder.Serializer[Row],
valueDeserializer: ExpressionEncoder.Deserializer[Row],
valueSerializer: ExpressionEncoder.Serializer[Row])

object TransformWithStateInPandasStateServer {
@volatile private var id = 0
Contributor commented:
IIUC, TransformWithStateInPandasPythonRunner is instantiated per executor, so the instances live in different JVMs, and TransformWithStateInPandasStateServer is initialized once from each TransformWithStateInPandasPythonRunner. What is the scenario that requires handling multiple threads here?

@bogao007 (Contributor, Author) replied on Nov 18, 2024:
There's not much difference compared with the TCP socket approach. For UDS we create a socket file for each server thread connection (with TCP we use a different port for each), and the files are named with distinct serverIds that are allocated incrementally. At the end we clean up by deleting these socket files.


def allocateServerId(): Int = synchronized {
id = id + 1
return id
}
}
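An aside on the thread-safety discussion above: because allocateServerId is synchronized and is the only accessor, the @volatile annotation is effectively redundant. A hedged alternative sketch (hypothetical ServerIdAllocator name, not part of this PR) that expresses the same intent with an AtomicInteger:

import java.util.concurrent.atomic.AtomicInteger

object ServerIdAllocator {
  // AtomicInteger already provides thread-safe increments, so neither
  // @volatile nor an explicit synchronized block is needed; each caller
  // still receives a unique, monotonically increasing id.
  private val nextId = new AtomicInteger(0)

  def allocateServerId(): Int = nextId.incrementAndGet()
}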
@@ -17,11 +17,11 @@
package org.apache.spark.sql.execution.python

import java.io.DataOutputStream
import java.net.ServerSocket

import scala.collection.mutable

import com.google.protobuf.ByteString
import jnr.unixsocket.UnixServerSocketChannel
import org.mockito.ArgumentMatchers.{any, argThat}
import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterEach
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach {
val stateName = "test"
val iteratorId = "testId"
val serverSocket: ServerSocket = mock(classOf[ServerSocket])
val serverChannel: UnixServerSocketChannel = mock(classOf[UnixServerSocketChannel])
val groupingKeySchema: StructType = StructType(Seq())
val stateSchema: StructType = StructType(Array(StructField("value", IntegerType)))
// Below byte array is a serialized row with a single integer value 1.
@@ -96,7 +96,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter])
batchTimestampMs = mock(classOf[Option[Long]])
eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -271,7 +271,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId ->
Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), getIntegerRow(4)))
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
Expand All @@ -295,7 +295,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = ListStateCall.newBuilder().setStateName(stateName)
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -384,7 +384,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row, Row)]](iteratorId ->
Iterator((getIntegerRow(1), getIntegerRow(1)), (getIntegerRow(2), getIntegerRow(2)),
(getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), getIntegerRow(4))))
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -408,7 +408,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = MapStateCall.newBuilder().setStateName(stateName)
.setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build()
val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer,
@@ -437,7 +437,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = MapStateCall.newBuilder().setStateName(stateName)
.setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer,
@@ -465,7 +465,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = MapStateCall.newBuilder().setStateName(stateName)
.setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer,
@@ -561,7 +561,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
.setList(ListTimers.newBuilder().setIteratorId("non-exist").build())
.build()
).build()
stateServer = new TransformWithStateInPandasStateServer(serverSocket,
stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
2, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer,