diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 3d885ffdb02d..63484c23a920 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -212,6 +212,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   protected val hideTraceback: Boolean = false
   protected val simplifiedTraceback: Boolean = false
 
+  protected val runnerConf: Map[String, String] = Map.empty
+
   // All the Python functions should have the same exec, version and envvars.
   protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
   protected val pythonExec: String = funcs.head.funcs.head.pythonExec
@@ -403,6 +405,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
      */
     protected def writeCommand(dataOut: DataOutputStream): Unit
 
+    /**
+     * Writes worker configuration to the stream connected to the Python worker.
+     */
+    protected def writeRunnerConf(dataOut: DataOutputStream): Unit = {
+      dataOut.writeInt(runnerConf.size)
+      for ((k, v) <- runnerConf) {
+        PythonWorkerUtils.writeUTF(k, dataOut)
+        PythonWorkerUtils.writeUTF(v, dataOut)
+      }
+    }
+
     /**
      * Writes input data to the stream connected to the Python worker.
      * Returns true if any data was written to the stream, false if the input is exhausted.
@@ -532,6 +545,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
 
         dataOut.writeInt(evalType)
+        writeRunnerConf(dataOut)
         writeCommand(dataOut)
 
         dataOut.flush()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 50e71fb6da9d..74a01f0659a5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -1512,10 +1512,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
 # It expects the UDTF to be in a specific format and performs various checks to
 # ensure the UDTF is valid. This function also prepares a mapper function for applying
 # the UDTF logic to input rows.
-def read_udtf(pickleSer, infile, eval_type):
+def read_udtf(pickleSer, infile, eval_type, runner_conf):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
-        # Load conf used for arrow evaluation.
-        runner_conf = RunnerConf(infile)
         input_types = [
             field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
         ]
@@ -1530,7 +1528,6 @@ def read_udtf(pickleSer, infile, eval_type):
         else:
             ser = ArrowStreamUDTFSerializer()
     elif eval_type == PythonEvalType.SQL_ARROW_UDTF:
-        runner_conf = RunnerConf(infile)
         # Read the table argument offsets
         num_table_arg_offsets = read_int(infile)
         table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
@@ -1538,7 +1535,6 @@ def read_udtf(pickleSer, infile, eval_type):
         ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
     else:
         # Each row is a group so do not batch but send one by one.
-        runner_conf = RunnerConf()
         ser = BatchedSerializer(CPickleSerializer(), 1)
 
     # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
@@ -2686,7 +2682,7 @@ def mapper(_, it):
     return mapper, None, ser, ser
 
 
-def read_udfs(pickleSer, infile, eval_type):
+def read_udfs(pickleSer, infile, eval_type, runner_conf):
     state_server_port = None
     key_schema = None
     if eval_type in (
@@ -2714,9 +2710,6 @@ def read_udfs(pickleSer, infile, eval_type):
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
         PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
     ):
-        # Load conf used for pandas_udf evaluation
-        runner_conf = RunnerConf(infile)
-
         state_object_schema = None
         if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
             state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
@@ -2868,7 +2861,6 @@ def read_udfs(pickleSer, infile, eval_type):
             int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
         )
     else:
-        runner_conf = RunnerConf()
         batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
         ser = BatchedSerializer(CPickleSerializer(), batch_size)
 
@@ -3351,6 +3343,7 @@ def main(infile, outfile):
         _accumulatorRegistry.clear()
 
         eval_type = read_int(infile)
+        runner_conf = RunnerConf(infile)
         if eval_type == PythonEvalType.NON_UDF:
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
         elif eval_type in (
@@ -3358,9 +3351,13 @@ def main(infile, outfile):
             PythonEvalType.SQL_ARROW_TABLE_UDF,
             PythonEvalType.SQL_ARROW_UDTF,
         ):
-            func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
+            func, profiler, deserializer, serializer = read_udtf(
+                pickleSer, infile, eval_type, runner_conf
+            )
         else:
-            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
+            func, profiler, deserializer, serializer = read_udfs(
+                pickleSer, infile, eval_type, runner_conf
+            )
 
         init_time = time.time()
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index f5f968ee9522..499fa99a2444 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
     _schema: StructType,
     _timeZoneId: String,
     protected override val largeVarTypes: Boolean,
-    protected override val workerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])
@@ -86,12 +85,11 @@ abstract class RowInputArrowPythonRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    workerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])
   extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
-    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID)
   with BasicPythonArrowInput
   with BasicPythonArrowOutput
@@ -106,13 +104,13 @@ class ArrowPythonRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    workerConf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
-    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
+    funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit =
@@ -130,13 +128,13 @@ class ArrowPythonWithNamedArgumentRunner(
     _schema: StructType,
     _timeZoneId: String,
     largeVarTypes: Boolean,
-    workerConf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
     profiler: Option[String])
   extends RowInputArrowPythonRunner(
-    funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
+    funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
     pythonMetrics, jobArtifactUUID, sessionUUID) {
 
   override protected def writeUDF(dataOut: DataOutputStream): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 1d5df9bad924..979d91205d5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
     protected override val schema: StructType,
     protected override val timeZoneId: String,
     protected override val largeVarTypes: Boolean,
-    protected override val workerConf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 7f6efbae8881..b5986be9214a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -25,7 +25,7 @@ import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
 import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner(
     rightSchema: StructType,
     timeZoneId: String,
     largeVarTypes: Boolean,
-    conf: Map[String, String],
+    protected override val runnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     sessionUUID: Option[String],
@@ -119,14 +119,6 @@ class CoGroupedArrowPythonRunner(
       private var rightGroupArrowWriter: ArrowWriterWrapper = null
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
-        // Write config for the worker as a number of key -> value pairs of strings
-        dataOut.writeInt(conf.size)
-        for ((k, v) <- conf) {
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v, dataOut)
-        }
-
         PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index f77b0a9342b0..d2d16b0c9623 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -27,7 +27,7 @@ import org.apache.arrow.vector.ipc.WriteChannel
 import org.apache.arrow.vector.ipc.message.MessageSerializer
 
 import org.apache.spark.{SparkEnv, SparkException, TaskContext}
-import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
+import org.apache.spark.api.python.{BasePythonRunner, PythonWorker}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow
 import org.apache.spark.sql.execution.arrow.{ArrowWriter, ArrowWriterWrapper}
@@ -42,8 +42,6 @@ import org.apache.spark.util.Utils
  * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
 private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
-  protected val workerConf: Map[String, String]
-
   protected val schema: StructType
 
   protected val timeZoneId: String
@@ -62,14 +60,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected def writeUDF(dataOut: DataOutputStream): Unit
 
-  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
-    // Write config for the worker as a number of key -> value pairs of strings
-    stream.writeInt(workerConf.size)
-    for ((k, v) <- workerConf) {
-      PythonRDD.writeUTF(k, stream)
-      PythonRDD.writeUTF(v, stream)
-    }
-  }
+  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}
+
   private val arrowSchema = ArrowUtils.toArrowSchema(
     schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
   protected val allocator =
@@ -301,7 +293,6 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
       context: TaskContext): Writer = {
     new Writer(env, worker, inputIterator, partitionIndex, context) {
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        handleMetadataBeforeExec(dataOut)
         writeUDF(dataOut)
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 14054ba89a94..ae89ff1637ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -58,7 +58,7 @@ class ApplyInPandasWithStatePythonRunner(
     argOffsets: Array[Array[Int]],
     inputSchema: StructType,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     stateEncoder: ExpressionEncoder[Row],
     keySchema: StructType,
     outputSchema: StructType,
@@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner(
   // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
   // Configurations are both applied to executor and Python worker, set them to the worker conf
   // to let Python worker read the config properly.
-  override protected val workerConf: Map[String, String] = initialWorkerConf +
+  override protected val runnerConf: Map[String, String] = initialRunnerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
     (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
index 42d4ad68c29a..10f86ce21d12 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala
@@ -52,7 +52,7 @@ class TransformWithStateInPySparkPythonRunner(
     _schema: StructType,
     processorHandle: StatefulProcessorHandleImpl,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
@@ -60,7 +60,7 @@ class TransformWithStateInPySparkPythonRunner(
     eventTimeWatermarkForEviction: Option[Long])
   extends TransformWithStateInPySparkPythonBaseRunner[InType](
     funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
-    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
     batchTimestampMs, eventTimeWatermarkForEviction)
   with PythonArrowInput[InType] {
 
@@ -126,7 +126,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
     initStateSchema: StructType,
     processorHandle: StatefulProcessorHandleImpl,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
@@ -134,7 +134,7 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
     eventTimeWatermarkForEviction: Option[Long])
   extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
     funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
-    initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
+    initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
     batchTimestampMs, eventTimeWatermarkForEviction)
   with PythonArrowInput[GroupedInType] {
 
@@ -195,7 +195,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
     _schema: StructType,
     processorHandle: StatefulProcessorHandleImpl,
     _timeZoneId: String,
-    initialWorkerConf: Map[String, String],
+    initialRunnerConf: Map[String, String],
     override val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
     groupingKeySchema: StructType,
@@ -212,7 +212,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
   protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
   protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch
 
-  override protected val workerConf: Map[String, String] = initialWorkerConf +
+  override protected val runnerConf: Map[String, String] = initialRunnerConf +
     (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
     (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)
 
@@ -225,7 +225,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
 
   override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
     super.handleMetadataBeforeExec(stream)
-    // Also write the port/path number for state server
+    // Write the port/path number for state server
    if (isUnixDomainSock) {
       stream.writeInt(-1)
       PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
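
For reference, a minimal sketch of the reading side of the framing introduced above: writeRunnerConf emits the number of entries, then each key and value as length-prefixed UTF-8 strings, which is what RunnerConf(infile) consumes in worker.py. The read_runner_conf helper below is illustrative only (it is not the RunnerConf implementation, which this patch does not show); it assumes the existing read_int and UTF8Deserializer utilities from pyspark.serializers.

    # Illustrative sketch, not part of the patch: mirrors BasePythonRunner.writeRunnerConf.
    from pyspark.serializers import read_int, UTF8Deserializer

    utf8_deserializer = UTF8Deserializer()

    def read_runner_conf(infile):
        # writeRunnerConf writes the entry count, then each key and value
        # as length-prefixed UTF-8 strings.
        conf = {}
        for _ in range(read_int(infile)):
            key = utf8_deserializer.loads(infile)
            value = utf8_deserializer.loads(infile)
            conf[key] = value
        return conf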