From dba5d0c5302411dc968fdf3dfa2698419d7fb94a Mon Sep 17 00:00:00 2001
From: bogao007
Date: Mon, 18 Nov 2024 11:40:46 -0800
Subject: [PATCH 1/2] Use UDS for JVM and Python worker communication

---
 .../stateful_processor_api_client.py          |  7 ++--
 python/pyspark/worker.py                      |  8 ++---
 sql/core/pom.xml                              |  5 +++
 ...ansformWithStateInPandasPythonRunner.scala | 34 ++++++++++++-------
 ...ransformWithStateInPandasStateServer.scala | 25 ++++++++++----
 ...ormWithStateInPandasStateServerSuite.scala | 20 +++++------
 6 files changed, 62 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 353f75e26796..b122915bb78e 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -48,10 +48,11 @@ class StatefulProcessorHandleState(Enum):
 
 
 class StatefulProcessorApiClient:
-    def __init__(self, state_server_port: int, key_schema: StructType) -> None:
+    def __init__(self, state_server_id: int, key_schema: StructType) -> None:
         self.key_schema = key_schema
-        self._client_socket = socket.socket()
-        self._client_socket.connect(("localhost", state_server_port))
+        self._client_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        server_address = f"./uds_{state_server_id}.sock"
+        self._client_socket.connect(server_address)
         self.sockfile = self._client_socket.makefile(
             "rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
         )
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 04f95e9f5264..7d4903d8b87b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1449,7 +1449,7 @@ def mapper(_, it):
 
 def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
-    state_server_port = None
+    state_server_id = None
     key_schema = None
     if eval_type in (
         PythonEvalType.SQL_ARROW_BATCHED_UDF,
@@ -1481,7 +1481,7 @@ def read_udfs(pickleSer, infile, eval_type):
             eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF
             or eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
         ):
-            state_server_port = read_int(infile)
+            state_server_id = read_int(infile)
             key_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
 
         # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
@@ -1695,7 +1695,7 @@ def mapper(a):
         )
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
-        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+        stateful_processor_api_client = StatefulProcessorApiClient(state_server_id, key_schema)
 
         # Create function like this:
         # mapper a: f([a[0]], [a[0], a[1]])
@@ -1728,7 +1728,7 @@ def values_gen():
         parsed_offsets = extract_key_value_indexes(arg_offsets)
         ser.key_offsets = parsed_offsets[0][0]
         ser.init_key_offsets = parsed_offsets[1][0]
-        stateful_processor_api_client = StatefulProcessorApiClient(state_server_port, key_schema)
+        stateful_processor_api_client = StatefulProcessorApiClient(state_server_id, key_schema)
 
         def mapper(a):
             key = a[0]
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 47c9ca0ea7a1..2e79ce0c4f8d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -252,6 +252,11 @@
       <artifactId>htmlunit3-driver</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>com.github.jnr</groupId>
+      <artifactId>jnr-unixsocket</artifactId>
+      <version>0.38.18</version>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
index c5980012124f..da324fe7472a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala
@@ -18,10 +18,12 @@ package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
-import java.net.ServerSocket
+import java.nio.file.{Files, Path}
 
 import scala.concurrent.ExecutionContext
 
+import jnr.unixsocket.UnixServerSocketChannel
+import jnr.unixsocket.UnixSocketAddress
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
@@ -181,7 +183,9 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
   protected val sqlConf = SQLConf.get
   protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
 
-  private var stateServerSocketPort: Int = 0
+  private val serverId = TransformWithStateInPandasStateServer.allocateServerId()
+
+  private val socketPath = s"./uds_$serverId.sock"
 
   override protected val workerConf: Map[String, String] = initialWorkerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
@@ -195,8 +199,8 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
 
   override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     super.handleMetadataBeforeExec(stream)
-    // Also write the port number for state server
-    stream.writeInt(stateServerSocketPort)
+    // Also write the server id for state server
+    stream.writeInt(serverId)
     PythonRDD.writeUTF(groupingKeySchema.json, stream)
   }
 
@@ -204,19 +208,21 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
       inputIterator: Iterator[I],
       partitionIndex: Int,
      context: TaskContext): Iterator[ColumnarBatch] = {
-    var stateServerSocket: ServerSocket = null
+    var serverChannel: UnixServerSocketChannel = null
     var failed = false
     try {
-      stateServerSocket = new ServerSocket( /* port = */ 0,
-        /* backlog = */ 1)
-      stateServerSocketPort = stateServerSocket.getLocalPort
+      val socketFile = Path.of(socketPath)
+      Files.deleteIfExists(socketFile)
+      val serverAddress = new UnixSocketAddress(socketPath)
+      serverChannel = UnixServerSocketChannel.open()
+      serverChannel.socket().bind(serverAddress)
     } catch {
       case e: Throwable =>
         failed = true
         throw e
     } finally {
       if (failed) {
-        closeServerSocketChannelSilently(stateServerSocket)
+        closeServerSocketChannelSilently(serverChannel)
       }
     }
 
@@ -224,7 +230,7 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
     val executionContext = ExecutionContext.fromExecutor(executor)
 
     executionContext.execute(
-      new TransformWithStateInPandasStateServer(stateServerSocket, processorHandle,
+      new TransformWithStateInPandasStateServer(serverChannel, processorHandle,
        groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes,
        sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch, batchTimestampMs,
        eventTimeWatermarkForEviction))
@@ -232,16 +238,18 @@ abstract class TransformWithStateInPandasPythonBaseRunner[I](
     context.addTaskCompletionListener[Unit] { _ =>
       logInfo(log"completion listener called")
       executor.shutdownNow()
-      closeServerSocketChannelSilently(stateServerSocket)
+      closeServerSocketChannelSilently(serverChannel)
+      val socketFile = Path.of(socketPath)
+      Files.deleteIfExists(socketFile)
     }
 
     super.compute(inputIterator, partitionIndex, context)
   }
 
-  private def closeServerSocketChannelSilently(stateServerSocket: ServerSocket): Unit = {
+  private def closeServerSocketChannelSilently(serverChannel: UnixServerSocketChannel): Unit = {
     try {
       logInfo(log"closing the state server socket")
-      stateServerSocket.close()
+      serverChannel.close()
     } catch {
       case e: Exception =>
         logError(log"failed to close state server socket", e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
index 0373c8607ff2..ac15556596d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala
@@ -17,13 +17,14 @@ package org.apache.spark.sql.execution.python
 
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException}
-import java.net.ServerSocket
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.channels.Channels
 import java.time.Duration
 
 import scala.collection.mutable
 
 import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
@@ -51,7 +52,7 @@ import org.apache.spark.util.Utils
  * - Requests for managing state variables (e.g. valueState).
  */
 class TransformWithStateInPandasStateServer(
-    stateServerSocket: ServerSocket,
+    serverChannel: UnixServerSocketChannel,
     statefulProcessorHandle: StatefulProcessorHandleImpl,
     groupingKeySchema: StructType,
     timeZoneId: String,
@@ -134,14 +135,15 @@ class TransformWithStateInPandasStateServer(
   } else new mutable.HashMap[String, Iterator[Long]]()
 
   def run(): Unit = {
-    val listeningSocket = stateServerSocket.accept()
+    val channel = serverChannel.accept()
     inputStream = new DataInputStream(
-      new BufferedInputStream(listeningSocket.getInputStream))
+      Channels.newInputStream(channel)
+    )
     outputStream = new DataOutputStream(
-      new BufferedOutputStream(listeningSocket.getOutputStream)
+      Channels.newOutputStream(channel)
     )
 
-    while (listeningSocket.isConnected &&
+    while (channel.isConnected &&
       statefulProcessorHandle.getHandleState != StatefulProcessorHandleState.CLOSED) {
       try {
         val version = inputStream.readInt()
@@ -758,3 +760,12 @@ case class MapStateInfo(
     keySerializer: ExpressionEncoder.Serializer[Row],
     valueDeserializer: ExpressionEncoder.Deserializer[Row],
     valueSerializer: ExpressionEncoder.Serializer[Row])
+
+object TransformWithStateInPandasStateServer {
+  @volatile private var id = 0
+
+  def allocateServerId(): Int = synchronized {
+    id = id + 1
+    return id
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
index e05264825f77..05f6ecb5e23a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala
@@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.python
 
 import java.io.DataOutputStream
-import java.net.ServerSocket
 
 import scala.collection.mutable
 
 import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
 import org.mockito.ArgumentMatchers.{any, argThat}
 import org.mockito.Mockito.{mock, times, verify, when}
 import org.scalatest.BeforeAndAfterEach
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach {
   val stateName = "test"
   val iteratorId = "testId"
-  val serverSocket: ServerSocket = mock(classOf[ServerSocket])
+  val serverChannel: UnixServerSocketChannel = mock(classOf[UnixServerSocketChannel])
   val groupingKeySchema: StructType = StructType(Seq())
   val stateSchema: StructType = StructType(Array(StructField("value", IntegerType)))
   // Below byte array is a serialized row with a single integer value 1.
@@ -96,7 +96,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter])
     batchTimestampMs = mock(classOf[Option[Long]])
     eventTimeWatermarkForEviction = mock(classOf[Option[Long]])
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -271,7 +271,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
       .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId ->
       Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3), getIntegerRow(4)))
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -295,7 +295,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     val message = ListStateCall.newBuilder().setStateName(stateName)
       .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -384,7 +384,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     val keyValueIteratorMap = mutable.HashMap[String, Iterator[(Row, Row)]](iteratorId ->
       Iterator((getIntegerRow(1), getIntegerRow(1)), (getIntegerRow(2), getIntegerRow(2)),
         (getIntegerRow(3), getIntegerRow(3)), (getIntegerRow(4), getIntegerRow(4))))
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer, arrowStreamWriter,
@@ -408,7 +408,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     val message = MapStateCall.newBuilder().setStateName(stateName)
       .setIterator(StateMessage.Iterator.newBuilder().setIteratorId(iteratorId).build()).build()
     val keyValueIteratorMap: mutable.HashMap[String, Iterator[(Row, Row)]] = mutable.HashMap()
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer,
@@ -437,7 +437,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     val message = MapStateCall.newBuilder().setStateName(stateName)
       .setKeys(Keys.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer,
@@ -465,7 +465,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     val message = MapStateCall.newBuilder().setStateName(stateName)
       .setValues(Values.newBuilder().setIteratorId(iteratorId).build()).build()
     val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap()
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, maxRecordsPerBatch,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer,
@@ -561,7 +561,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
         .setList(ListTimers.newBuilder().setIteratorId("non-exist").build())
         .build()
     ).build()
-    stateServer = new TransformWithStateInPandasStateServer(serverSocket,
+    stateServer = new TransformWithStateInPandasStateServer(serverChannel,
       statefulProcessorHandle, groupingKeySchema, "", false, false, 2,
       batchTimestampMs, eventTimeWatermarkForEviction, outputStream, valueStateMap,
       transformWithStateInPandasDeserializer,

From 0ff3cdee331b0855607131e76eb1684e910bb5dd Mon Sep 17 00:00:00 2001
From: bogao007
Date: Mon, 18 Nov 2024 15:09:49 -0800
Subject: [PATCH 2/2] fix linter

---
 dev/deps/spark-deps-hadoop-3-hive-2.3 | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index a866cf7b725c..e3c090481ea4 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -21,6 +21,11 @@ arrow-memory-core/18.0.0//arrow-memory-core-18.0.0.jar
 arrow-memory-netty-buffer-patch/18.0.0//arrow-memory-netty-buffer-patch-18.0.0.jar
 arrow-memory-netty/18.0.0//arrow-memory-netty-18.0.0.jar
 arrow-vector/18.0.0//arrow-vector-18.0.0.jar
+asm-analysis/9.2//asm-analysis-9.2.jar
+asm-commons/9.2//asm-commons-9.2.jar
+asm-tree/9.2//asm-tree-9.2.jar
+asm-util/9.2//asm-util-9.2.jar
+asm/9.2//asm-9.2.jar
 audience-annotations/0.12.0//audience-annotations-0.12.0.jar
 avro-ipc/1.12.0//avro-ipc-1.12.0.jar
 avro-mapred/1.12.0//avro-mapred-1.12.0.jar
@@ -142,10 +147,18 @@ jersey-server/3.0.16//jersey-server-3.0.16.jar
 jettison/1.5.4//jettison-1.5.4.jar
 jetty-util-ajax/11.0.24//jetty-util-ajax-11.0.24.jar
 jetty-util/11.0.24//jetty-util-11.0.24.jar
+jffi/1.3.9//jffi-1.3.9.jar
+jffi/1.3.9/native/jffi-1.3.9-native.jar
 jjwt-api/0.12.6//jjwt-api-0.12.6.jar
 jline/2.14.6//jline-2.14.6.jar
 jline/3.26.3//jline-3.26.3.jar
 jna/5.14.0//jna-5.14.0.jar
+jnr-a64asm/1.0.0//jnr-a64asm-1.0.0.jar
+jnr-constants/0.10.3//jnr-constants-0.10.3.jar
+jnr-enxio/0.32.13//jnr-enxio-0.32.13.jar
+jnr-ffi/2.2.11//jnr-ffi-2.2.11.jar
+jnr-unixsocket/0.38.18//jnr-unixsocket-0.38.18.jar
+jnr-x86asm/1.0.2//jnr-x86asm-1.0.2.jar
 joda-time/2.13.0//joda-time-2.13.0.jar
 jodd-core/3.5.2//jodd-core-3.5.2.jar
 jpam/1.1//jpam-1.1.jar