diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index a866cf7b725c..e3c090481ea4 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -21,6 +21,11 @@ arrow-memory-core/18.0.0//arrow-memory-core-18.0.0.jar
arrow-memory-netty-buffer-patch/18.0.0//arrow-memory-netty-buffer-patch-18.0.0.jar
arrow-memory-netty/18.0.0//arrow-memory-netty-18.0.0.jar
arrow-vector/18.0.0//arrow-vector-18.0.0.jar
+asm-analysis/9.2//asm-analysis-9.2.jar
+asm-commons/9.2//asm-commons-9.2.jar
+asm-tree/9.2//asm-tree-9.2.jar
+asm-util/9.2//asm-util-9.2.jar
+asm/9.2//asm-9.2.jar
audience-annotations/0.12.0//audience-annotations-0.12.0.jar
avro-ipc/1.12.0//avro-ipc-1.12.0.jar
avro-mapred/1.12.0//avro-mapred-1.12.0.jar
@@ -142,10 +147,18 @@ jersey-server/3.0.16//jersey-server-3.0.16.jar
jettison/1.5.4//jettison-1.5.4.jar
jetty-util-ajax/11.0.24//jetty-util-ajax-11.0.24.jar
jetty-util/11.0.24//jetty-util-11.0.24.jar
+jffi/1.3.9//jffi-1.3.9.jar
+jffi/1.3.9/native/jffi-1.3.9-native.jar
jjwt-api/0.12.6//jjwt-api-0.12.6.jar
jline/2.14.6//jline-2.14.6.jar
jline/3.26.3//jline-3.26.3.jar
jna/5.14.0//jna-5.14.0.jar
+jnr-a64asm/1.0.0//jnr-a64asm-1.0.0.jar
+jnr-constants/0.10.3//jnr-constants-0.10.3.jar
+jnr-enxio/0.32.13//jnr-enxio-0.32.13.jar
+jnr-ffi/2.2.11//jnr-ffi-2.2.11.jar
+jnr-unixsocket/0.38.18//jnr-unixsocket-0.38.18.jar
+jnr-x86asm/1.0.2//jnr-x86asm-1.0.2.jar
joda-time/2.13.0//joda-time-2.13.0.jar
jodd-core/3.5.2//jodd-core-3.5.2.jar
jpam/1.1//jpam-1.1.jar
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 353f75e26796..b122915bb78e 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -48,10 +48,11 @@ class StatefulProcessorHandleState(Enum):
class StatefulProcessorApiClient:
- def __init__(self, state_server_port: int, key_schema: StructType) -> None:
+ def __init__(self, state_server_id: int, key_schema: StructType) -> None:
self.key_schema = key_schema
- self._client_socket = socket.socket()
- self._client_socket.connect(("localhost", state_server_port))
+ self._client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ server_address = f"./uds_{state_server_id}.sock"
+ self._client_socket.connect(server_address)
self.sockfile = self._client_socket.makefile(
"rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 04f95e9f5264..7d4903d8b87b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1449,7 +1449,7 @@ def mapper(_, it):
def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}
- state_server_port = None
+ state_server_id = None
key_schema = None
if eval_type in (
PythonEvalType.SQL_ARROW_BATCHED_UDF,
@@ -1481,7 +1481,7 @@ def read_udfs(pickleSer, infile, eval_type):
eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
):
- state_server_port = read_int(infile)
+ state_server_id = read_int(infile)
key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
@@ -1695,7 +1695,7 @@ def mapper(a):
)
parsed_offsets = extract_key_value_indexes(arg_offsets)
ser.key_offsets = parsed_offsets[0][0]
- stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+ stateful_processor_api_client = StatefulProcessorApiClient(state_server_id, key_schema)
# Create function like this:
# mapper a: f([a[0]], [a[0], a[1]])
@@ -1728,7 +1728,7 @@ def values_gen():
parsed_offsets = extract_key_value_indexes(arg_offsets)
ser.key_offsets = parsed_offsets[0][0]
ser.init_key_offsets = parsed_offsets[1][0]
- stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+ stateful_processor_api_client = StatefulProcessorApiClient(state_server_id, key_schema)
def mapper(a):
key = a[0]
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 47c9ca0ea7a1..2e79ce0c4f8d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -252,6 +252,11 @@
      <artifactId>htmlunit3-driver</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>com.github.jnr</groupId>
+      <artifactId>jnr-unixsocket</artifactId>
+      <version>0.38.18</version>
+    </dependency>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
index c5980012124f..da324fe7472a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
@@ -18,10 +18,12 @@
package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
-import java.net.ServerSocket
+import java.nio.file.{Files, Path}
import scala.concurrent.ExecutionContext
+import jnr.unixsocket.UnixServerSocketChannel
+import jnr.unixsocket.UnixSocketAddress
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
@@ -181,7 +183,9 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
protected val sqlConf = SQLConf.get
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
- private var stateServerSocketPort: Int = 0
+ private val serverId = TransformWithStateInPandasStateServer.allocateServerId()
+
+ private val socketPath = s"./uds_$serverId.sock"
override protected val workerConf: Map[String, String] = initialWorkerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
@@ -195,8 +199,8 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
- // Also write the port number for state server
- stream.writeInt(stateServerSocketPort)
+    // Also write the server id for state server
+ stream.writeInt(serverId)
PythonRDD.writeUTF(groupingKeySchema.json, stream)
}
@@ -204,19 +208,21 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
inputIterator: Iterator[I],
partitionIndex: Int,
context: TaskContext): Iterator[ColumnarBatch] = {
- var stateServerSocket: ServerSocket = null
+ var serverChannel: UnixServerSocketChannel = null
var failed = false
try {
- stateServerSocket = new ServerSocket( /* port = */ 0,
- /* backlog = */ 1)
- stateServerSocketPort = stateServerSocket.getLocalPort
+ val socketFile = Path.of(socketPath)
+ Files.deleteIfExists(socketFile)
+ val serverAddress = new UnixSocketAddress(socketPath)
+ serverChannel = UnixServerSocketChannel.open()
+ serverChannel.socket().bind(serverAddress)
} catch {
case e: Throwable =>
failed = true
throw e
} finally {
if (failed) {
- closeServerSocketChannelSilently(stateServerSocket)
+ closeServerSocketChannelSilently(serverChannel)
}
}
@@ -224,7 +230,7 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
val executionContext = ExecutionContext.fromExecutor(executor)
executionContext.execute(
- new TransformWithStateInPandasStateServer(stateServerSocket, processorHandle,
+ new TransformWithStateInPandasStateServer(serverChannel, processorHandle,
groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch,
batchTimestampMs, eventTimeWatermarkForEviction))
@@ -232,16 +238,18 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
context.addTaskCompletionListener[Unit] { _ =>
logInfo(log"completion listener called")
executor.shutdownNow()
- closeServerSocketChannelSilently(stateServerSocket)
+ closeServerSocketChannelSilently(serverChannel)
+ val socketFile = Path.of(socketPath)
+ Files.deleteIfExists(socketFile)
}
super.compute(inputIterator, partitionIndex, context)
}
- private def closeServerSocketChannelSilently(stateServerSocket: ServerSocket): Unit = {
+ private def closeServerSocketChannelSilently(serverChannel: UnixServerSocketChannel): Unit = {
try {
logInfo(log"closing the state server socket")
- stateServerSocket.close()
+ serverChannel.close()
} catch {
case e: Exception =>
logError(log"failed to close state server socket", e)
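
For readers less familiar with jnr-unixsocket, here is a minimal, self-contained sketch of the server lifecycle the runner follows above: delete any stale socket file, bind a UnixServerSocketChannel to a filesystem path, and clean up both the channel and the file when the task finishes. The object name and path below are illustrative only, not part of this patch.

```scala
import java.nio.file.{Files, Path}

import jnr.unixsocket.{UnixServerSocketChannel, UnixSocketAddress}

// Illustrative sketch only; the runner derives the real path from an allocated server id.
object UdsServerLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val socketPath = "./uds_example.sock"
    // A leftover socket file from a previous run would make bind() fail, so clear it first.
    Files.deleteIfExists(Path.of(socketPath))

    val serverChannel = UnixServerSocketChannel.open()
    serverChannel.socket().bind(new UnixSocketAddress(socketPath))
    try {
      // The real runner hands serverChannel to TransformWithStateInPandasStateServer here.
      println(s"state server listening on $socketPath")
    } finally {
      // Mirrors the task-completion listener: close the channel and remove the socket file.
      serverChannel.close()
      Files.deleteIfExists(Path.of(socketPath))
    }
  }
}
```
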
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 0373c8607ff2..ac15556596d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -17,13 +17,14 @@
package org.apache.spark.sql.execution.python
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
-import java.net.ServerSocket
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.channels.Channels
import java.time.Duration
import scala.collection.mutable
import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamWriter
@@ -51,7 +52,7 @@ import org.apache.spark.util.Utils
* - Requests for managing state variables (e.g. valueState).
*/
class TransformWithStateInPandasStateServer(
- stateServerSocket: ServerSocket,
+ serverChannel: UnixServerSocketChannel,
statefulProcessorHandle: StatefulProcessorHandleImpl,
groupingKeySchema: StructType,
timeZoneId: String,
@@ -134,14 +135,15 @@ class TransformWithStateInPandasStateServer(
} else new mutable.HashMap[String, Iterator[Long]]()
def run(): Unit = {
- val listeningSocket = stateServerSocket.accept()
+ val channel = serverChannel.accept()
inputStream = new DataInputStream(
- new BufferedInputStream(listeningSocket.getInputStream))
+ Channels.newInputStream(channel)
+ )
outputStream = new DataOutputStream(
- new BufferedOutputStream(listeningSocket.getOutputStream)
+ Channels.newOutputStream(channel)
)
- while (listeningSocket.isConnected &&
+ while (channel.isConnected &&
statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) {
try {
val version = inputStream.readInt()
@@ -758,3 +760,12 @@ case class MapStateInfo(
keySerializer: ExpressionEncoder.Serializer[Row],
valueDeserializer: ExpressionEncoder.Deserializer[Row],
valueSerializer: ExpressionEncoder.Serializer[Row])
+
+object TransformWithStateInPandasStateServer {
+ @volatile private var id = 0
+
+ def allocateServerId(): Int = synchronized {
+    id += 1
+    id
+ }
+}
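
As a rough illustration of how the accepted channel is turned into the DataInputStream/DataOutputStream pair used by run(), the sketch below wires a client and a server over one Unix domain socket and exchanges a single int. All names and the path are made up for the example; in the actual change the Python worker plays the client role over AF_UNIX.

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.nio.channels.Channels
import java.nio.file.{Files, Path}

import jnr.unixsocket.{UnixServerSocketChannel, UnixSocketAddress, UnixSocketChannel}

object UdsEchoSketch {
  def main(args: Array[String]): Unit = {
    val path = "./uds_echo_example.sock" // illustrative path
    Files.deleteIfExists(Path.of(path))
    val server = UnixServerSocketChannel.open()
    server.socket().bind(new UnixSocketAddress(path))

    // Client side: connect to the bound path and write a single int.
    val client = UnixSocketChannel.open(new UnixSocketAddress(path))
    val clientOut = new DataOutputStream(Channels.newOutputStream(client))
    clientOut.writeInt(42)
    clientOut.flush()

    // Server side: accept the connection and wrap the channel in data streams,
    // the same pattern run() uses above.
    val channel = server.accept()
    val in = new DataInputStream(Channels.newInputStream(channel))
    assert(in.readInt() == 42)

    channel.close(); client.close(); server.close()
    Files.deleteIfExists(Path.of(path))
  }
}
```
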
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index e05264825f77..05f6ecb5e23a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
-import java.net.ServerSocket
import scala.collection.mutable
import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
import org.mockito.ArgumentMatchers.{any, argThat}
import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest.BeforeAndAfterEach
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach {
val stateName = "test"
val iteratorId = "testId"
- val serverSocket: ServerSocket = mock(classOf[ServerSocket])
+ val serverChannel: UnixServerSocketChannel = mock(classOf[UnixServerSocketChannel])
val groupingKeySchema: StructType = StructType(Seq())
val stateSchema: StructType = StructType(Array(StructField("value", IntegerType)))
// Below byte array is a serialized row with a single integer value 1.
@@ -96,7 +96,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter])
batchTimestampMs = mock(classOf[Option[Long]])
eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -271,7 +271,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId ->
Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), getIntegerRow(4)))
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -295,7 +295,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = ListStateCall.newBuilder().setStateName(stateName)
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -384,7 +384,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row, Row)]](iteratorId ->
Iterator((getIntegerRow(1), getIntegerRow(1)), (getIntegerRow(2), getIntegerRow(2)),
(getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), getIntegerRow(4))))
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -408,7 +408,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = MapStateCall.newBuilder().setStateName(stateName)
.setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build()
val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = mutable.HashMap()
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer,
@@ -437,7 +437,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = MapStateCall.newBuilder().setStateName(stateName)
.setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction,
outputStream, valueStateMap, transformWithStateInPandasDeserializer,
@@ -465,7 +465,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = MapStateCall.newBuilder().setStateName(stateName)
.setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build()
val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
maxRecordsPerBatch, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer,
@@ -561,7 +561,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
.setList(ListTimers.newBuilder().setIteratorId("non-exist").build())
.build()
).build()
- stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+ stateServer = new TransformWithStateInPandasStateServer(serverChannel,
statefulProcessorHandle, groupingKeySchema, "", false, false,
2, batchTimestampMs, eventTimeWatermarkForEviction, outputStream,
valueStateMap, transformWithStateInPandasDeserializer,