14 changes: 14 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -212,6 +212,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false

protected val runnerConf: Map[String, String] = Map.empty

// All the Python functions should have the same exec, version and envvars.
protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
protected val pythonExec: String = funcs.head.funcs.head.pythonExec
@@ -403,6 +405,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
*/
protected def writeCommand(dataOut: DataOutputStream): Unit

/**
* Writes worker configuration to the stream connected to the Python worker.
*/
protected def writeRunnerConf(dataOut: DataOutputStream): Unit = {
dataOut.writeInt(runnerConf.size)
for ((k, v) <- runnerConf) {
PythonWorkerUtils.writeUTF(k, dataOut)
PythonWorkerUtils.writeUTF(v, dataOut)
}
}

/**
* Writes input data to the stream connected to the Python worker.
* Returns true if any data was written to the stream, false if the input is exhausted.
@@ -532,6 +545,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)

dataOut.writeInt(evalType)
writeRunnerConf(dataOut)
writeCommand(dataOut)

dataOut.flush()
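As the hunks above show, `writeRunnerConf` frames the conf as a 32-bit pair count followed by length-prefixed UTF-8 keys and values (`PythonWorkerUtils.writeUTF`), written between the eval type and the command. The sketch below is a minimal Python reader for that framing; it is illustrative only, since the `RunnerConf` class that `worker.py` actually uses is not part of this diff.

```python
from pyspark.serializers import UTF8Deserializer, read_int

utf8_deserializer = UTF8Deserializer()


def read_runner_conf(infile):
    # Counterpart to BasePythonRunner.writeRunnerConf: an int pair count,
    # then alternating length-prefixed UTF-8 keys and values.
    conf = {}
    for _ in range(read_int(infile)):
        key = utf8_deserializer.loads(infile)
        conf[key] = utf8_deserializer.loads(infile)
    return conf
```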
21 changes: 9 additions & 12 deletions python/pyspark/worker.py
@@ -1512,10 +1512,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
# It expects the UDTF to be in a specific format and performs various checks to
# ensure the UDTF is valid. This function also prepares a mapper function for applying
# the UDTF logic to input rows.
def read_udtf(pickleSer, infile, eval_type):
def read_udtf(pickleSer, infile, eval_type, runner_conf):
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
# Load conf used for arrow evaluation.
runner_conf = RunnerConf(infile)
input_types = [
field.dataType for field in _parse_datatype_json_string(utf8_deserializer.loads(infile))
]
@@ -1530,15 +1528,13 @@ def read_udtf(pickleSer, infile, eval_type):
else:
ser = ArrowStreamUDTFSerializer()
elif eval_type == PythonEvalType.SQL_ARROW_UDTF:
runner_conf = RunnerConf(infile)
# Read the table argument offsets
num_table_arg_offsets = read_int(infile)
table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
# Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
else:
# Each row is a group so do not batch but send one by one.
runner_conf = RunnerConf()
ser = BatchedSerializer(CPickleSerializer(), 1)

# See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
@@ -2686,7 +2682,7 @@ def mapper(_, it):
return mapper, None, ser, ser


def read_udfs(pickleSer, infile, eval_type):
def read_udfs(pickleSer, infile, eval_type, runner_conf):
state_server_port = None
key_schema = None
if eval_type in (
@@ -2714,9 +2710,6 @@ def read_udfs(pickleSer, infile, eval_type):
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_INIT_STATE_UDF,
):
# Load conf used for pandas_udf evaluation
runner_conf = RunnerConf(infile)

state_object_schema = None
if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
@@ -2868,7 +2861,6 @@ def read_udfs(pickleSer, infile, eval_type):
int_to_decimal_coercion_enabled=runner_conf.int_to_decimal_coercion_enabled,
)
else:
runner_conf = RunnerConf()
batch_size = int(os.environ.get("PYTHON_UDF_BATCH_SIZE", "100"))
ser = BatchedSerializer(CPickleSerializer(), batch_size)

@@ -3351,16 +3343,21 @@ def main(infile, outfile):

_accumulatorRegistry.clear()
eval_type = read_int(infile)
runner_conf = RunnerConf(infile)
if eval_type == PythonEvalType.NON_UDF:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
elif eval_type in (
PythonEvalType.SQL_TABLE_UDF,
PythonEvalType.SQL_ARROW_TABLE_UDF,
PythonEvalType.SQL_ARROW_UDTF,
):
func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
func, profiler, deserializer, serializer = read_udtf(
pickleSer, infile, eval_type, runner_conf
)
else:
func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
func, profiler, deserializer, serializer = read_udfs(
pickleSer, infile, eval_type, runner_conf
)

init_time = time.time()

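With the per-branch `RunnerConf(infile)` and `RunnerConf()` calls removed, `main` now reads the conf exactly once, for every eval type, right after the eval type and before the serialized command, mirroring the JVM write order in `BasePythonRunner` (`writeInt(evalType)`, `writeRunnerConf`, `writeCommand`). A minimal sketch of that preamble, with a plain dict standing in for the real `RunnerConf` class (whose definition is not in this diff):

```python
from pyspark.serializers import UTF8Deserializer, read_int

utf8_deserializer = UTF8Deserializer()


def read_worker_preamble(infile):
    # Mirrors the JVM side: dataOut.writeInt(evalType); writeRunnerConf(dataOut);
    # writeCommand(dataOut). read_udtf/read_udfs consume the command afterwards
    # and now receive runner_conf as an argument instead of reading it themselves.
    eval_type = read_int(infile)
    runner_conf = {}  # stand-in for RunnerConf(infile)
    for _ in range(read_int(infile)):
        key = utf8_deserializer.loads(infile)
        runner_conf[key] = utf8_deserializer.loads(infile)
    return eval_type, runner_conf
```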
@@ -35,7 +35,6 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
_schema: StructType,
_timeZoneId: String,
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
@@ -86,12 +85,11 @@ abstract class RowInputArrowPythonRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
workerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
extends BaseArrowPythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID)
with BasicPythonArrowInput
with BasicPythonArrowOutput
@@ -106,13 +104,13 @@ class ArrowPythonRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
workerConf: Map[String, String],
protected override val runnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf,
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {

override protected def writeUDF(dataOut: DataOutputStream): Unit =
@@ -130,13 +128,13 @@ class ArrowPythonWithNamedArgumentRunner(
_schema: StructType,
_timeZoneId: String,
largeVarTypes: Boolean,
workerConf: Map[String, String],
protected override val runnerConf: Map[String, String],
pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
profiler: Option[String])
extends RowInputArrowPythonRunner(
funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf,
funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes,
pythonMetrics, jobArtifactUUID, sessionUUID) {

override protected def writeUDF(dataOut: DataOutputStream): Unit = {
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
protected override val schema: StructType,
protected override val timeZoneId: String,
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
protected override val runnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String])
@@ -25,7 +25,7 @@ import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -45,7 +45,7 @@ class CoGroupedArrowPythonRunner(
rightSchema: StructType,
timeZoneId: String,
largeVarTypes: Boolean,
conf: Map[String, String],
protected override val runnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
sessionUUID: Option[String],
@@ -119,14 +119,6 @@ class CoGroupedArrowPythonRunner(
private var rightGroupArrowWriter: ArrowWriterWrapper = null

protected override def writeCommand(dataOut: DataOutputStream): Unit = {

// Write config for the worker as a number of key -> value pairs of strings
dataOut.writeInt(conf.size)
for ((k, v) <- conf) {
PythonRDD.writeUTF(k, dataOut)
PythonRDD.writeUTF(v, dataOut)
}

PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, profiler)
}

@@ -27,7 +27,7 @@ import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer

import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
import org.apache.spark.api.python.{BasePythonRunner, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow
import org.apache.spark.sql.execution.arrow.{ArrowWriter, ArrowWriterWrapper}
@@ -42,8 +42,6 @@ import org.apache.spark.util.Utils
* JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
*/
private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected val workerConf: Map[String, String]

protected val schema: StructType

protected val timeZoneId: String
@@ -62,14 +60,8 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>

protected def writeUDF(dataOut: DataOutputStream): Unit

protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
// Write config for the worker as a number of key -> value pairs of strings
stream.writeInt(workerConf.size)
for ((k, v) <- workerConf) {
PythonRDD.writeUTF(k, stream)
PythonRDD.writeUTF(v, stream)
}
}
protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {}

private val arrowSchema = ArrowUtils.toArrowSchema(
schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
protected val allocator =
@@ -301,7 +293,6 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
context: TaskContext): Writer = {
new Writer(env, worker, inputIterator, partitionIndex, context) {
protected override def writeCommand(dataOut: DataOutputStream): Unit = {
handleMetadataBeforeExec(dataOut)
writeUDF(dataOut)
}

@@ -58,7 +58,7 @@ class ApplyInPandasWithStatePythonRunner(
argOffsets: Array[Array[Int]],
inputSchema: StructType,
_timeZoneId: String,
initialWorkerConf: Map[String, String],
initialRunnerConf: Map[String, String],
stateEncoder: ExpressionEncoder[Row],
keySchema: StructType,
outputSchema: StructType,
@@ -113,7 +113,7 @@ class ApplyInPandasWithStatePythonRunner(
// applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
// Configurations are both applied to executor and Python worker, set them to the worker conf
// to let Python worker read the config properly.
override protected val workerConf: Map[String, String] = initialWorkerConf +
override protected val runnerConf: Map[String, String] = initialRunnerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)

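Because these two entries travel over the same runner-conf channel as everything else, the worker can size its Arrow batches from the decoded map. The sketch below shows how that lookup might be done; the literal key strings are assumptions standing in for `SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key` and `SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key`, whose values are not shown in this diff.

```python
def arrow_batch_limits(runner_conf):
    # runner_conf here is the decoded key/value dict (see the reader sketch
    # earlier), not the RunnerConf object from worker.py. The key strings are
    # assumed; both entries are always set by this runner, so no defaults.
    max_records = int(runner_conf["spark.sql.execution.arrow.maxRecordsPerBatch"])
    max_bytes = int(runner_conf["spark.sql.execution.arrow.maxBytesPerBatch"])
    return max_records, max_bytes
```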
@@ -52,15 +52,15 @@ class TransformWithStateInPySparkPythonRunner(
_schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
initialWorkerConf: Map[String, String],
initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
batchTimestampMs: Option[Long],
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[InType](
funcs, evalType, argOffsets, _schema, processorHandle, _timeZoneId,
initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[InType] {

@@ -126,15 +126,15 @@ class TransformWithStateInPySparkPythonInitialStateRunner(
initStateSchema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
initialWorkerConf: Map[String, String],
initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
batchTimestampMs: Option[Long],
eventTimeWatermarkForEviction: Option[Long])
extends TransformWithStateInPySparkPythonBaseRunner[GroupedInType](
funcs, evalType, argOffsets, dataSchema, processorHandle, _timeZoneId,
initialWorkerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
initialRunnerConf, pythonMetrics, jobArtifactUUID, groupingKeySchema,
batchTimestampMs, eventTimeWatermarkForEviction)
with PythonArrowInput[GroupedInType] {

@@ -195,7 +195,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
_schema: StructType,
processorHandle: StatefulProcessorHandleImpl,
_timeZoneId: String,
initialWorkerConf: Map[String, String],
initialRunnerConf: Map[String, String],
override val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
groupingKeySchema: StructType,
@@ -212,7 +212,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](
protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch

override protected val workerConf: Map[String, String] = initialWorkerConf +
override protected val runnerConf: Map[String, String] = initialRunnerConf +
(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
(SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)

@@ -225,7 +225,7 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](

override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
super.handleMetadataBeforeExec(stream)
// Also write the port/path number for state server
// Write the port/path number for state server
if (isUnixDomainSock) {
stream.writeInt(-1)
PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
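Since `super.handleMetadataBeforeExec` no longer writes any conf, this override now carries only the state-server endpoint: a sentinel of `-1` followed by a UTF-8 socket path for Unix domain sockets, otherwise (in the branch outside the visible hunk) presumably a plain port number. A hypothetical worker-side reader for that metadata:

```python
from pyspark.serializers import UTF8Deserializer, read_int

utf8_deserializer = UTF8Deserializer()


def read_state_server_endpoint(infile):
    # Hypothetical counterpart to handleMetadataBeforeExec above: -1 signals a
    # Unix domain socket path; any other value is assumed to be a TCP port
    # (that branch is not visible in this hunk).
    port_or_sentinel = read_int(infile)
    if port_or_sentinel == -1:
        return ("uds", utf8_deserializer.loads(infile))
    return ("tcp", port_or_sentinel)
```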