|
19 | 19 | import sys |
20 | 20 | import tempfile |
21 | 21 | import unittest |
| 22 | +import logging |
| 23 | +import json |
| 24 | +import os |
| 25 | +from dataclasses import dataclass |
22 | 26 | from datetime import datetime |
23 | 27 | from decimal import Decimal |
24 | 28 | from typing import Callable, Iterable, List, Union, Iterator, Tuple |
|
57 | 61 | have_pyarrow, |
58 | 62 | pyarrow_requirement_message, |
59 | 63 | ) |
| 64 | +from pyspark.util import is_remote_only |
60 | 65 |
|
61 | 66 |
|
62 | 67 | @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) |
@@ -907,6 +912,357 @@ def commit(self, messages): |
907 | 912 | "test_table" |
908 | 913 | ) |
909 | 914 |
|
| 915 | + @unittest.skipIf(is_remote_only(), "Requires JVM access") |
| 916 | + def test_data_source_reader_with_logging(self): |
| 917 | + logger = logging.getLogger("test_data_source_reader") |
| 918 | + |
| 919 | + class TestJsonReader(DataSourceReader): |
| 920 | + def __init__(self, options): |
| 921 | + logger.warning(f"TestJsonReader.__init__: {list(options)}") |
| 922 | + self.options = options |
| 923 | + |
| 924 | + def partitions(self): |
| 925 | + logger.warning("TestJsonReader.partitions") |
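| | + # Deferring to the base-class partitions() is assumed to fall back |
| | + # to a single default partition, matching the "read: None" log below. |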
| 926 | + return super().partitions() |
| 927 | + |
| 928 | + def read(self, partition): |
| 929 | + logger.warning(f"TestJsonReader.read: {partition}") |
| 930 | + path = self.options.get("path") |
| 931 | + if path is None: |
| 932 | + raise Exception("path is not specified") |
| 933 | + with open(path, "r") as file: |
| 934 | + for line in file: |
| 935 | + if line.strip(): |
| 936 | + data = json.loads(line) |
| 937 | + yield data.get("name"), data.get("age") |
| 938 | + |
| 939 | + class TestJsonDataSource(DataSource): |
| 940 | + def __init__(self, options): |
| 941 | + super().__init__(options) |
| 942 | + logger.warning(f"TestJsonDataSource.__init__: {list(options)}") |
| 943 | + |
| 944 | + @classmethod |
| 945 | + def name(cls): |
| 946 | + logger.warning("TestJsonDataSource.name") |
| 947 | + return "my-json" |
| 948 | + |
| 949 | + def schema(self): |
| 950 | + logger.warning("TestJsonDataSource.schema") |
| 951 | + return "name STRING, age INT" |
| 952 | + |
| 953 | + def reader(self, schema) -> "DataSourceReader": |
| 954 | + logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}") |
| 955 | + return TestJsonReader(self.options) |
| 956 | + |
| 957 | + self.spark.dataSource.register(TestJsonDataSource) |
| 958 | + path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") |
| 959 | + |
| 960 | + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): |
| 961 | + assertDataFrameEqual( |
| 962 | + self.spark.read.format("my-json").load(path1), |
| 963 | + [ |
| 964 | + Row(name="Michael", age=None), |
| 965 | + Row(name="Andy", age=30), |
| 966 | + Row(name="Justin", age=19), |
| 967 | + ], |
| 968 | + ) |
| 969 | + |
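| | + # With the worker-logging conf enabled above, records logged in the |
| | + # Python worker surface in this session-scoped system table. |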
| 970 | + logs = self.spark.table("system.session.python_worker_logs") |
| 971 | + |
| 972 | + assertDataFrameEqual( |
| 973 | + logs.select("level", "msg", "context", "logger"), |
| 974 | + [ |
| 975 | + Row( |
| 976 | + level="WARNING", |
| 977 | + msg=msg, |
| 978 | + context=context, |
| 979 | + logger="test_data_source_reader", |
| 980 | + ) |
| 981 | + for msg, context in [ |
| 982 | + ( |
| 983 | + "TestJsonDataSource.__init__: ['path']", |
| 984 | + {"class_name": "TestJsonDataSource", "func_name": "__init__"}, |
| 985 | + ), |
| 986 | + ( |
| 987 | + "TestJsonDataSource.name", |
| 988 | + {"class_name": "TestJsonDataSource", "func_name": "name"}, |
| 989 | + ), |
| 990 | + ( |
| 991 | + "TestJsonDataSource.schema", |
| 992 | + {"class_name": "TestJsonDataSource", "func_name": "schema"}, |
| 993 | + ), |
| 994 | + ( |
| 995 | + "TestJsonDataSource.reader: ['name', 'age']", |
| 996 | + {"class_name": "TestJsonDataSource", "func_name": "reader"}, |
| 997 | + ), |
| 998 | + ( |
| 999 | + "TestJsonReader.__init__: ['path']", |
| 1000 | + {"class_name": "TestJsonDataSource", "func_name": "reader"}, |
| 1001 | + ), |
| 1002 | + ( |
| 1003 | + "TestJsonReader.partitions", |
| 1004 | + {"class_name": "TestJsonReader", "func_name": "partitions"}, |
| 1005 | + ), |
| 1006 | + ( |
| 1007 | + "TestJsonReader.read: None", |
| 1008 | + {"class_name": "TestJsonReader", "func_name": "read"}, |
| 1009 | + ), |
| 1010 | + ] |
| 1011 | + ], |
| 1012 | + ) |
| 1013 | + |
| 1014 | + @unittest.skipIf(is_remote_only(), "Requires JVM access") |
| 1015 | + def test_data_source_reader_pushdown_with_logging(self): |
| 1016 | + logger = logging.getLogger("test_data_source_reader_pushdown") |
| 1017 | + |
| 1018 | + class TestJsonReader(DataSourceReader): |
| 1019 | + def __init__(self, options): |
| 1020 | + logger.warning(f"TestJsonReader.__init__: {list(options)}") |
| 1021 | + self.options = options |
| 1022 | + |
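| | + # Delegating to the base pushFilters is assumed to decline every |
| | + # filter, so Spark still applies the IsNotNull filter after the scan. |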
| 1023 | + def pushFilters(self, filters): |
| 1024 | + logger.warning(f"TestJsonReader.pushFilters: {filters}") |
| 1025 | + return super().pushFilters(filters) |
| 1026 | + |
| 1027 | + def partitions(self): |
| 1028 | + logger.warning("TestJsonReader.partitions") |
| 1029 | + return super().partitions() |
| 1030 | + |
| 1031 | + def read(self, partition): |
| 1032 | + logger.warning(f"TestJsonReader.read: {partition}") |
| 1033 | + path = self.options.get("path") |
| 1034 | + if path is None: |
| 1035 | + raise Exception("path is not specified") |
| 1036 | + with open(path, "r") as file: |
| 1037 | + for line in file: |
| 1038 | + if line.strip(): |
| 1039 | + data = json.loads(line) |
| 1040 | + yield data.get("name"), data.get("age") |
| 1041 | + |
| 1042 | + class TestJsonDataSource(DataSource): |
| 1043 | + def __init__(self, options): |
| 1044 | + super().__init__(options) |
| 1045 | + logger.warning(f"TestJsonDataSource.__init__: {list(options)}") |
| 1046 | + |
| 1047 | + @classmethod |
| 1048 | + def name(cls): |
| 1049 | + logger.warning("TestJsonDataSource.name") |
| 1050 | + return "my-json" |
| 1051 | + |
| 1052 | + def schema(self): |
| 1053 | + logger.warning("TestJsonDataSource.schema") |
| 1054 | + return "name STRING, age INT" |
| 1055 | + |
| 1056 | + def reader(self, schema) -> "DataSourceReader": |
| 1057 | + logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}") |
| 1058 | + return TestJsonReader(self.options) |
| 1059 | + |
| 1060 | + self.spark.dataSource.register(TestJsonDataSource) |
| 1061 | + path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") |
| 1062 | + |
| 1063 | + with self.sql_conf( |
| 1064 | + { |
| 1065 | + "spark.sql.python.filterPushdown.enabled": "true", |
| 1066 | + "spark.sql.pyspark.worker.logging.enabled": "true", |
| 1067 | + } |
| 1068 | + ): |
| 1069 | + assertDataFrameEqual( |
| 1070 | + self.spark.read.format("my-json").load(path1).filter("age is not null"), |
| 1071 | + [ |
| 1072 | + Row(name="Andy", age=30), |
| 1073 | + Row(name="Justin", age=19), |
| 1074 | + ], |
| 1075 | + ) |
| 1076 | + |
| 1077 | + logs = self.spark.table("system.session.python_worker_logs") |
| 1078 | + |
| 1079 | + assertDataFrameEqual( |
| 1080 | + logs.select("level", "msg", "context", "logger"), |
| 1081 | + [ |
| 1082 | + Row( |
| 1083 | + level="WARNING", |
| 1084 | + msg=msg, |
| 1085 | + context=context, |
| 1086 | + logger="test_data_source_reader_pushdown", |
| 1087 | + ) |
| 1088 | + for msg, context in [ |
| 1089 | + ( |
| 1090 | + "TestJsonDataSource.__init__: ['path']", |
| 1091 | + {"class_name": "TestJsonDataSource", "func_name": "__init__"}, |
| 1092 | + ), |
| 1093 | + ( |
| 1094 | + "TestJsonDataSource.name", |
| 1095 | + {"class_name": "TestJsonDataSource", "func_name": "name"}, |
| 1096 | + ), |
| 1097 | + ( |
| 1098 | + "TestJsonDataSource.schema", |
| 1099 | + {"class_name": "TestJsonDataSource", "func_name": "schema"}, |
| 1100 | + ), |
| 1101 | + ( |
| 1102 | + "TestJsonDataSource.reader: ['name', 'age']", |
| 1103 | + {"class_name": "TestJsonDataSource", "func_name": "reader"}, |
| 1104 | + ), |
| 1105 | + ( |
| 1106 | + "TestJsonReader.pushFilters: [IsNotNull(attribute=('age',))]", |
| 1107 | + {"class_name": "TestJsonReader", "func_name": "pushFilters"}, |
| 1108 | + ), |
| 1109 | + ( |
| 1110 | + "TestJsonReader.__init__: ['path']", |
| 1111 | + {"class_name": "TestJsonDataSource", "func_name": "reader"}, |
| 1112 | + ), |
| 1113 | + ( |
| 1114 | + "TestJsonReader.partitions", |
| 1115 | + {"class_name": "TestJsonReader", "func_name": "partitions"}, |
| 1116 | + ), |
| 1117 | + ( |
| 1118 | + "TestJsonReader.read: None", |
| 1119 | + {"class_name": "TestJsonReader", "func_name": "read"}, |
| 1120 | + ), |
| 1121 | + ] |
| 1122 | + ], |
| 1123 | + ) |
| 1124 | + |
| 1125 | + @unittest.skipIf(is_remote_only(), "Requires JVM access") |
| 1126 | + def test_data_source_writer_with_logging(self): |
| 1127 | + logger = logging.getLogger("test_datasource_writer") |
| 1128 | + |
| 1129 | + @dataclass |
| 1130 | + class TestCommitMessage(WriterCommitMessage): |
| 1131 | + count: int |
| 1132 | + |
| 1133 | + class TestJsonWriter(DataSourceWriter): |
| 1134 | + def __init__(self, options): |
| 1135 | + logger.warning(f"TestJsonWriter.__init__: {list(options)}") |
| 1136 | + self.options = options |
| 1137 | + self.path = self.options.get("path") |
| 1138 | + |
| 1139 | + def write(self, iterator): |
| 1140 | + from pyspark import TaskContext |
| 1141 | + |
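| | + # Raising here should route the job through abort(); the log |
| | + # assertions below rely on this failure path. |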
| 1142 | + if self.options.get("abort", None): |
| 1143 | + logger.warning("TestJsonWriter.write: abort test") |
| 1144 | + raise Exception("abort test") |
| 1145 | + |
| 1146 | + context = TaskContext.get() |
| 1147 | + output_path = os.path.join(self.path, f"{context.partitionId()}.json") |
| 1148 | + count = 0 |
| 1149 | + rows = [] |
| 1150 | + with open(output_path, "w") as file: |
| 1151 | + for row in iterator: |
| 1152 | + count += 1 |
| 1153 | + rows.append(row.asDict()) |
| 1154 | + file.write(json.dumps(row.asDict()) + "\n") |
| 1155 | + |
| 1156 | + logger.warning(f"TestJsonWriter.write: {count}, {rows}") |
| 1157 | + |
| 1158 | + return TestCommitMessage(count=count) |
| 1159 | + |
| 1160 | + def commit(self, messages): |
| 1161 | + total_count = sum(message.count for message in messages) |
| 1162 | + with open(os.path.join(self.path, "_success.txt"), "w") as file: |
| 1163 | + file.write(f"count: {total_count}\n") |
| 1164 | + |
| 1165 | + logger.warning(f"TestJsonWriter.commit: {total_count}") |
| 1166 | + |
| 1167 | + def abort(self, messages): |
| 1168 | + with open(os.path.join(self.path, "_failed.txt"), "w") as file: |
| 1169 | + file.write("failed") |
| 1170 | + |
| 1171 | + logger.warning("TestJsonWriter.abort") |
| 1172 | + |
| 1173 | + class TestJsonDataSource(DataSource): |
| 1174 | + @classmethod |
| 1175 | + def name(cls): |
| 1176 | + logger.warning("TestJsonDataSource.name") |
| 1177 | + return "my-json" |
| 1178 | + |
| 1179 | + def writer(self, schema, overwrite): |
| 1180 | + logger.warning(f"TestJsonDataSource.writer: {schema.fieldNames()}, {overwrite}") |
| 1181 | + return TestJsonWriter(self.options) |
| 1182 | + |
| 1183 | + # Register the data source |
| 1184 | + self.spark.dataSource.register(TestJsonDataSource) |
| 1185 | + |
| 1186 | + with tempfile.TemporaryDirectory(prefix="test_datasource_write_logging") as d: |
| 1187 | + with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}): |
| 1188 | + # Create a simple DataFrame and write it using our custom datasource |
| 1189 | + df = self.spark.createDataFrame( |
| 1190 | + [("Charlie", 35), ("Diana", 28)], "name STRING, age INT" |
| 1191 | + ).repartitionByRange(2, "age") |
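| | + # Two range partitions mean write() runs once per partition, which |
| | + # produces the two per-partition "write" log entries asserted below. |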
| 1192 | + df.write.format("my-json").mode("overwrite").save(d) |
| 1193 | + |
| 1194 | + # Verify the write worked by checking the success file |
| 1195 | + with open(os.path.join(d, "_success.txt"), "r") as file: |
| 1196 | + text = file.read() |
| 1197 | + self.assertEqual(text, "count: 2\n") |
| 1198 | + |
| 1199 | + with self.assertRaisesRegex(Exception, "abort test"): |
| 1200 | + df.write.format("my-json").mode("append").option("abort", "true").save(d) |
| 1201 | + |
| 1202 | + logs = self.spark.table("system.session.python_worker_logs") |
| 1203 | + |
| 1204 | + assertDataFrameEqual( |
| 1205 | + logs.select("level", "msg", "context", "logger"), |
| 1206 | + [ |
| 1207 | + Row( |
| 1208 | + level="WARNING", |
| 1209 | + msg=msg, |
| 1210 | + context=context, |
| 1211 | + logger="test_datasource_writer", |
| 1212 | + ) |
| 1213 | + for msg, context in [ |
| 1214 | + ( |
| 1215 | + "TestJsonDataSource.name", |
| 1216 | + {"class_name": "TestJsonDataSource", "func_name": "name"}, |
| 1217 | + ), |
| 1218 | + ( |
| 1219 | + "TestJsonDataSource.writer: ['name', 'age'], True", |
| 1220 | + {"class_name": "TestJsonDataSource", "func_name": "writer"}, |
| 1221 | + ), |
| 1222 | + ( |
| 1223 | + "TestJsonWriter.__init__: ['path']", |
| 1224 | + {"class_name": "TestJsonDataSource", "func_name": "writer"}, |
| 1225 | + ), |
| 1226 | + ( |
| 1227 | + "TestJsonWriter.write: 1, [{'name': 'Diana', 'age': 28}]", |
| 1228 | + {"class_name": "TestJsonWriter", "func_name": "write"}, |
| 1229 | + ), |
| 1230 | + ( |
| 1231 | + "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age': 35}]", |
| 1232 | + {"class_name": "TestJsonWriter", "func_name": "write"}, |
| 1233 | + ), |
| 1234 | + ( |
| 1235 | + "TestJsonWriter.commit: 2", |
| 1236 | + {"class_name": "TestJsonWriter", "func_name": "commit"}, |
| 1237 | + ), |
| 1238 | + ( |
| 1239 | + "TestJsonDataSource.name", |
| 1240 | + {"class_name": "TestJsonDataSource", "func_name": "name"}, |
| 1241 | + ), |
| 1242 | + ( |
| 1243 | + "TestJsonDataSource.writer: ['name', 'age'], False", |
| 1244 | + {"class_name": "TestJsonDataSource", "func_name": "writer"}, |
| 1245 | + ), |
| 1246 | + ( |
| 1247 | + "TestJsonWriter.__init__: ['abort', 'path']", |
| 1248 | + {"class_name": "TestJsonDataSource", "func_name": "writer"}, |
| 1249 | + ), |
| 1250 | + ( |
| 1251 | + "TestJsonWriter.write: abort test", |
| 1252 | + {"class_name": "TestJsonWriter", "func_name": "write"}, |
| 1253 | + ), |
| 1254 | + ( |
| 1255 | + "TestJsonWriter.write: abort test", |
| 1256 | + {"class_name": "TestJsonWriter", "func_name": "write"}, |
| 1257 | + ), |
| 1258 | + ( |
| 1259 | + "TestJsonWriter.abort", |
| 1260 | + {"class_name": "TestJsonWriter", "func_name": "abort"}, |
| 1261 | + ), |
| 1262 | + ] |
| 1263 | + ], |
| 1264 | + ) |
| 1265 | + |
910 | 1266 |
|
911 | 1267 | class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase): |
912 | 1268 | ... |
|