46 changes: 46 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -17,6 +17,7 @@
import os
import time
import unittest
import logging

from pyspark.errors import PythonException
from pyspark.sql import Row
@@ -26,6 +27,8 @@
have_pyarrow,
pyarrow_requirement_message,
)
from pyspark.testing.utils import assertDataFrameEqual
from pyspark.util import is_remote_only

if have_pyarrow:
import pyarrow as pa
@@ -367,6 +370,49 @@ def test_negative_and_zero_batch_size(self):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
CogroupedMapInArrowTestsMixin.test_apply_in_arrow(self)

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_cogroup_apply_in_arrow_with_logging(self):
import pyarrow as pa

def func_with_logging(left, right):
assert isinstance(left, pa.Table)
assert isinstance(right, pa.Table)
logger = logging.getLogger("test_arrow_cogrouped_map")
logger.warning(
"arrow cogrouped map: "
+ f"{dict(v1=left['v1'].to_pylist(), v2=right['v2'].to_pylist())}"
)
return left.join(right, keys="id", join_type="inner")

left_df = self.spark.createDataFrame([(1, 10), (2, 20), (1, 30)], ["id", "v1"])
right_df = self.spark.createDataFrame([(1, 100), (2, 200), (1, 300)], ["id", "v2"])

grouped_left = left_df.groupBy("id")
grouped_right = right_df.groupBy("id")
cogrouped_df = grouped_left.cogroup(grouped_right)

with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
assertDataFrameEqual(
cogrouped_df.applyInArrow(func_with_logging, "id long, v1 long, v2 long"),
[Row(id=1, v1=v1, v2=v2) for v1 in [10, 30] for v2 in [100, 300]]
+ [Row(id=2, v1=20, v2=200)],
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
context={"func_name": func_with_logging.__name__},
logger="test_arrow_cogrouped_map",
)
for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
],
)


class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
77 changes: 77 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -17,6 +17,7 @@
import inspect
import os
import time
import logging
from typing import Iterator, Tuple
import unittest

@@ -29,6 +30,8 @@
have_pyarrow,
pyarrow_requirement_message,
)
from pyspark.testing.utils import assertDataFrameEqual
from pyspark.util import is_remote_only

if have_pyarrow:
import pyarrow as pa
@@ -394,6 +397,80 @@ def test_negative_and_zero_batch_size(self):
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
ApplyInArrowTestsMixin.test_apply_in_arrow(self)

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_apply_in_arrow_with_logging(self):
import pyarrow as pa

def func_with_logging(group):
assert isinstance(group, pa.Table)
logger = logging.getLogger("test_arrow_grouped_map")
logger.warning(f"arrow grouped map: {group.to_pydict()}")
return group

df = self.spark.range(9).withColumn("value", col("id") * 10)
grouped_df = df.groupBy((col("id") % 2).cast("int"))

with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
assertDataFrameEqual(
grouped_df.applyInArrow(func_with_logging, "id long, value long"),
df,
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
context={"func_name": func_with_logging.__name__},
logger="test_arrow_grouped_map",
)
for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
],
)

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_apply_in_arrow_iter_with_logging(self):
import pyarrow as pa

def func_with_logging(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
logger = logging.getLogger("test_arrow_grouped_map")
for batch in group:
assert isinstance(batch, pa.RecordBatch)
logger.warning(f"arrow grouped map: {batch.to_pydict()}")
yield batch

df = self.spark.range(9).withColumn("value", col("id") * 10)
grouped_df = df.groupBy((col("id") % 2).cast("int"))

with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": 3,
"spark.sql.pyspark.worker.logging.enabled": "true",
}
):
assertDataFrameEqual(
grouped_df.applyInArrow(func_with_logging, "id long, value long"),
df,
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
context={"func_name": func_with_logging.__name__},
logger="test_arrow_grouped_map",
)
for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
],
)


class ApplyInArrowTests(ApplyInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
55 changes: 55 additions & 0 deletions python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -17,6 +17,7 @@
import os
import time
import unittest
import logging

from pyspark.sql.utils import PythonException
from pyspark.testing.sqlutils import (
@@ -26,6 +27,9 @@
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.sql import Row
from pyspark.testing.utils import assertDataFrameEqual
from pyspark.util import is_remote_only

if have_pyarrow:
import pyarrow as pa
@@ -221,6 +225,46 @@ def func(iterator):
df = self.spark.range(1)
df.mapInArrow(func, "a int").collect()

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_map_in_arrow_with_logging(self):
import pyarrow as pa

def func_with_logging(iterator):
logger = logging.getLogger("test_arrow_map")
for batch in iterator:
assert isinstance(batch, pa.RecordBatch)
logger.warning(f"arrow map: {batch.to_pydict()}")
yield batch

with self.sql_conf(
{
"spark.sql.execution.arrow.maxRecordsPerBatch": "3",
"spark.sql.pyspark.worker.logging.enabled": "true",
}
):
assertDataFrameEqual(
self.spark.range(9, numPartitions=2).mapInArrow(func_with_logging, "id long"),
[Row(id=i) for i in range(9)],
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
)

def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
return [
Row(
level="WARNING",
msg=f"arrow map: {dict(id=lst)}",
context={"func_name": func_name},
logger="test_arrow_map",
)
for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
]


class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
@classmethod
@@ -253,6 +297,17 @@ def setUpClass(cls):
cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")
cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")

def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
return [
Row(
level="WARNING",
msg=f"arrow map: {dict(id=[i])}",
context={"func_name": func_name},
logger="test_arrow_map",
)
for i in range(9)
]


class MapInArrowWithOutputArrowBatchSlicingRecordsTests(MapInArrowTests):
@classmethod
12 changes: 0 additions & 12 deletions python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -60,18 +60,6 @@ def test_register_java_function(self):
def test_register_java_udaf(self):
super(ArrowPythonUDFTests, self).test_register_java_udaf()

@unittest.skip(
"TODO(SPARK-53976): Python worker logging is not supported for Arrow Python UDFs."
)
def test_udf_with_logging(self):
super().test_udf_with_logging()

@unittest.skip(
"TODO(SPARK-53976): Python worker logging is not supported for Arrow Python UDFs."
)
def test_multiple_udfs_with_logging(self):
super().test_multiple_udfs_with_logging()

def test_complex_input_types(self):
row = (
self.spark.range(1)
40 changes: 39 additions & 1 deletion python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -16,9 +16,10 @@
#

import unittest
import logging

from pyspark.sql.functions import arrow_udf, ArrowUDFType
from pyspark.util import PythonEvalType
from pyspark.util import PythonEvalType, is_remote_only
from pyspark.sql import Row
from pyspark.sql.types import (
ArrayType,
@@ -35,6 +36,7 @@
numpy_requirement_message,
have_pyarrow,
pyarrow_requirement_message,
assertDataFrameEqual,
)
from pyspark.testing.sqlutils import ReusedSQLTestCase

@@ -1021,6 +1023,42 @@ def arrow_max(v):

self.assertEqual(expected, result)

@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_grouped_agg_arrow_udf_with_logging(self):
import pyarrow as pa

@arrow_udf("double", ArrowUDFType.GROUPED_AGG)
def my_grouped_agg_arrow_udf(x):
assert isinstance(x, pa.Array)
logger = logging.getLogger("test_grouped_agg_arrow")
logger.warning(f"grouped agg arrow udf: {len(x)}")
return pa.compute.sum(x)

df = self.spark.createDataFrame(
[(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
)

with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
assertDataFrameEqual(
df.groupby("id").agg(my_grouped_agg_arrow_udf("v").alias("result")),
[Row(id=1, result=3.0), Row(id=2, result=18.0)],
)

logs = self.spark.table("system.session.python_worker_logs")

assertDataFrameEqual(
logs.select("level", "msg", "context", "logger"),
[
Row(
level="WARNING",
msg=f"grouped agg arrow udf: {n}",
context={"func_name": my_grouped_agg_arrow_udf.__name__},
logger="test_grouped_agg_arrow",
)
for n in [2, 3]
],
)


class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
pass
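
Note: the new tests above all follow the same pattern: enable spark.sql.pyspark.worker.logging.enabled, run an Arrow-based function that logs through the standard logging module on the Python worker, and then assert on the rows exposed via the system.session.python_worker_logs table. The following is a minimal standalone sketch of that pattern outside the test harness; it assumes a classic-mode SparkSession named spark, and the logger name, message text, and batch-size settings are illustrative, not part of this change.

import logging

import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()


def func_with_logging(iterator):
    # Log records emitted on the Python worker are captured only when
    # spark.sql.pyspark.worker.logging.enabled is set to "true".
    logger = logging.getLogger("example_arrow_map")  # illustrative logger name
    for batch in iterator:
        assert isinstance(batch, pa.RecordBatch)
        logger.warning(f"arrow map: {batch.to_pydict()}")
        yield batch


spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")
spark.range(9, numPartitions=2).mapInArrow(func_with_logging, "id long").show()

# Worker-side log records (level, msg, context, logger, ...) are exposed
# through this session-scoped table, which the tests above query.
spark.table("system.session.python_worker_logs") \
    .select("level", "msg", "context", "logger") \
    .show(truncate=False)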