
Commit 6e41ac5

ueshin authored and dongjoon-hyun committed
[SPARK-53978][PYTHON] Support logging in driver-side workers
### What changes were proposed in this pull request?

Supports logging in driver-side workers.

### Why are the changes needed?

The basic logging infrastructure was introduced in #52689, and driver-side workers should also support logging. This change adds that support.

### Does this PR introduce _any_ user-facing change?

Yes, the logging feature will be available in driver-side workers.

### How was this patch tested?

Added the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52808 from ueshin/issues/SPARK-53978/driverside.

Authored-by: Takuya Ueshin <ueshin@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 7624476 · commit 6e41ac5

18 files changed (+717 -251 lines)
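
As a sketch of the user-facing flow these tests exercise (assuming a running SparkSession bound to `spark`; the conf key, table name, and column names are taken from the tests in this commit, while the logger name is hypothetical):

    import logging

    logger = logging.getLogger("my_data_source")  # hypothetical logger name

    # Enable Python worker logging for the session.
    spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")

    # ... run a Python data source whose methods call logger.warning(...) ...

    # Collected worker logs are exposed as a session-scoped system table.
    logs = spark.table("system.session.python_worker_logs")
    logs.select("level", "msg", "context", "logger").show(truncate=False)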

python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 356 additions & 0 deletions
@@ -19,6 +19,10 @@
 import sys
 import tempfile
 import unittest
+import logging
+import json
+import os
+from dataclasses import dataclass
 from datetime import datetime
 from decimal import Decimal
 from typing import Callable, Iterable, List, Union, Iterator, Tuple
@@ -57,6 +61,7 @@
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.util import is_remote_only
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -907,6 +912,357 @@ def commit(self, messages):
             "test_table"
         )
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_data_source_reader_with_logging(self):
+        logger = logging.getLogger("test_data_source_reader")
+
+        class TestJsonReader(DataSourceReader):
+            def __init__(self, options):
+                logger.warning(f"TestJsonReader.__init__: {list(options)}")
+                self.options = options
+
+            def partitions(self):
+                logger.warning("TestJsonReader.partitions")
+                return super().partitions()
+
+            def read(self, partition):
+                logger.warning(f"TestJsonReader.read: {partition}")
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class TestJsonDataSource(DataSource):
+            def __init__(self, options):
+                super().__init__(options)
+                logger.warning(f"TestJsonDataSource.__init__: {list(options)}")
+
+            @classmethod
+            def name(cls):
+                logger.warning("TestJsonDataSource.name")
+                return "my-json"
+
+            def schema(self):
+                logger.warning("TestJsonDataSource.schema")
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}")
+                return TestJsonReader(self.options)
+
+        self.spark.dataSource.register(TestJsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                self.spark.read.format("my-json").load(path1),
+                [
+                    Row(name="Michael", age=None),
+                    Row(name="Andy", age=30),
+                    Row(name="Justin", age=19),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=msg,
+                    context=context,
+                    logger="test_data_source_reader",
+                )
+                for msg, context in [
+                    (
+                        "TestJsonDataSource.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "__init__"},
+                    ),
+                    (
+                        "TestJsonDataSource.name",
+                        {"class_name": "TestJsonDataSource", "func_name": "name"},
+                    ),
+                    (
+                        "TestJsonDataSource.schema",
+                        {"class_name": "TestJsonDataSource", "func_name": "schema"},
+                    ),
+                    (
+                        "TestJsonDataSource.reader: ['name', 'age']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.partitions",
+                        {"class_name": "TestJsonReader", "func_name": "partitions"},
+                    ),
+                    (
+                        "TestJsonReader.read: None",
+                        {"class_name": "TestJsonReader", "func_name": "read"},
+                    ),
+                ]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_data_source_reader_pushdown_with_logging(self):
+        logger = logging.getLogger("test_data_source_reader_pushdown")
+
+        class TestJsonReader(DataSourceReader):
+            def __init__(self, options):
+                logger.warning(f"TestJsonReader.__init__: {list(options)}")
+                self.options = options
+
+            def pushFilters(self, filters):
+                logger.warning(f"TestJsonReader.pushFilters: {filters}")
+                return super().pushFilters(filters)
+
+            def partitions(self):
+                logger.warning("TestJsonReader.partitions")
+                return super().partitions()
+
+            def read(self, partition):
+                logger.warning(f"TestJsonReader.read: {partition}")
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class TestJsonDataSource(DataSource):
+            def __init__(self, options):
+                super().__init__(options)
+                logger.warning(f"TestJsonDataSource.__init__: {list(options)}")
+
+            @classmethod
+            def name(cls):
+                logger.warning("TestJsonDataSource.name")
+                return "my-json"
+
+            def schema(self):
+                logger.warning("TestJsonDataSource.schema")
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}")
+                return TestJsonReader(self.options)
+
+        self.spark.dataSource.register(TestJsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+
+        with self.sql_conf(
+            {
+                "spark.sql.python.filterPushdown.enabled": "true",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.read.format("my-json").load(path1).filter("age is not null"),
+                [
+                    Row(name="Andy", age=30),
+                    Row(name="Justin", age=19),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=msg,
+                    context=context,
+                    logger="test_data_source_reader_pushdown",
+                )
+                for msg, context in [
+                    (
+                        "TestJsonDataSource.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "__init__"},
+                    ),
+                    (
+                        "TestJsonDataSource.name",
+                        {"class_name": "TestJsonDataSource", "func_name": "name"},
+                    ),
+                    (
+                        "TestJsonDataSource.schema",
+                        {"class_name": "TestJsonDataSource", "func_name": "schema"},
+                    ),
+                    (
+                        "TestJsonDataSource.reader: ['name', 'age']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.pushFilters: [IsNotNull(attribute=('age',))]",
+                        {"class_name": "TestJsonReader", "func_name": "pushFilters"},
+                    ),
+                    (
+                        "TestJsonReader.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.partitions",
+                        {"class_name": "TestJsonReader", "func_name": "partitions"},
+                    ),
+                    (
+                        "TestJsonReader.read: None",
+                        {"class_name": "TestJsonReader", "func_name": "read"},
+                    ),
+                ]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_data_source_writer_with_logging(self):
+        logger = logging.getLogger("test_datasource_writer")
+
+        @dataclass
+        class TestCommitMessage(WriterCommitMessage):
+            count: int
+
+        class TestJsonWriter(DataSourceWriter):
+            def __init__(self, options):
+                logger.warning(f"TestJsonWriter.__init__: {list(options)}")
+                self.options = options
+                self.path = self.options.get("path")
+
+            def write(self, iterator):
+                from pyspark import TaskContext
+
+                if self.options.get("abort", None):
+                    logger.warning("TestJsonWriter.write: abort test")
+                    raise Exception("abort test")
+
+                context = TaskContext.get()
+                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
+                count = 0
+                rows = []
+                with open(output_path, "w") as file:
+                    for row in iterator:
+                        count += 1
+                        rows.append(row.asDict())
+                        file.write(json.dumps(row.asDict()) + "\n")
+
+                logger.warning(f"TestJsonWriter.write: {count}, {rows}")
+
+                return TestCommitMessage(count=count)
+
+            def commit(self, messages):
+                total_count = sum(message.count for message in messages)
+                with open(os.path.join(self.path, "_success.txt"), "w") as file:
+                    file.write(f"count: {total_count}\n")
+
+                logger.warning(f"TestJsonWriter.commit: {total_count}")
+
+            def abort(self, messages):
+                with open(os.path.join(self.path, "_failed.txt"), "w") as file:
+                    file.write("failed")
+
+                logger.warning("TestJsonWriter.abort")
+
+        class TestJsonDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                logger.warning("TestJsonDataSource.name")
+                return "my-json"
+
+            def writer(self, schema, overwrite):
+                logger.warning(f"TestJsonDataSource.writer: {schema.fieldNames(), {overwrite}}")
+                return TestJsonWriter(self.options)
+
+        # Register the data source
+        self.spark.dataSource.register(TestJsonDataSource)
+
+        with tempfile.TemporaryDirectory(prefix="test_datasource_write_logging") as d:
+            with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+                # Create a simple DataFrame and write it using our custom datasource
+                df = self.spark.createDataFrame(
+                    [("Charlie", 35), ("Diana", 28)], "name STRING, age INT"
+                ).repartitionByRange(2, "age")
+                df.write.format("my-json").mode("overwrite").save(d)
+
+                # Verify the write worked by checking the success file
+                with open(os.path.join(d, "_success.txt"), "r") as file:
+                    text = file.read()
+                self.assertEqual(text, "count: 2\n")
+
+                with self.assertRaises(Exception, msg="abort test"):
+                    df.write.format("my-json").mode("append").option("abort", "true").save(d)
+
+            logs = self.spark.table("system.session.python_worker_logs")
+
+            assertDataFrameEqual(
+                logs.select("level", "msg", "context", "logger"),
+                [
+                    Row(
+                        level="WARNING",
+                        msg=msg,
+                        context=context,
+                        logger="test_datasource_writer",
+                    )
+                    for msg, context in [
+                        (
+                            "TestJsonDataSource.name",
+                            {"class_name": "TestJsonDataSource", "func_name": "name"},
+                        ),
+                        (
+                            "TestJsonDataSource.writer: (['name', 'age'], {True})",
+                            {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                        ),
+                        (
+                            "TestJsonWriter.__init__: ['path']",
+                            {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                        ),
+                        (
+                            "TestJsonWriter.write: 1, [{'name': 'Diana', 'age': 28}]",
+                            {"class_name": "TestJsonWriter", "func_name": "write"},
+                        ),
+                        (
+                            "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age': 35}]",
+                            {"class_name": "TestJsonWriter", "func_name": "write"},
+                        ),
+                        (
+                            "TestJsonWriter.commit: 2",
+                            {"class_name": "TestJsonWriter", "func_name": "commit"},
+                        ),
+                        (
+                            "TestJsonDataSource.name",
+                            {"class_name": "TestJsonDataSource", "func_name": "name"},
+                        ),
+                        (
+                            "TestJsonDataSource.writer: (['name', 'age'], {False})",
+                            {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                        ),
+                        (
+                            "TestJsonWriter.__init__: ['abort', 'path']",
+                            {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                        ),
+                        (
+                            "TestJsonWriter.write: abort test",
+                            {"class_name": "TestJsonWriter", "func_name": "write"},
+                        ),
+                        (
+                            "TestJsonWriter.write: abort test",
+                            {"class_name": "TestJsonWriter", "func_name": "write"},
+                        ),
+                        (
+                            "TestJsonWriter.abort",
+                            {"class_name": "TestJsonWriter", "func_name": "abort"},
+                        ),
+                    ]
+                ],
+            )
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
