1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -452,6 +452,7 @@ def __hash__(self):
"pyspark.sql.tests.test_group",
"pyspark.sql.tests.test_pandas_cogrouped_map",
"pyspark.sql.tests.test_pandas_grouped_map",
"pyspark.sql.tests.test_pandas_grouped_map_with_state",
"pyspark.sql.tests.test_pandas_map",
"pyspark.sql.tests.test_arrow_map",
"pyspark.sql.tests.test_pandas_udf",
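Registering the module above lets the PySpark test runner discover the new suite; it can presumably also be run on its own with `python/run-tests --testnames pyspark.sql.tests.test_pandas_grouped_map_with_state` (invocation inferred from the usual dev workflow).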
103 changes: 103 additions & 0 deletions python/pyspark/sql/tests/test_pandas_grouped_map_with_state.py
@@ -0,0 +1,103 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
from typing import cast

from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
from pyspark.sql.types import (
    LongType,
    StringType,
    StructType,
    StructField,
    Row,
)
from pyspark.testing.sqlutils import (
    ReusedSQLTestCase,
    have_pandas,
    have_pyarrow,
    pandas_requirement_message,
    pyarrow_requirement_message,
)

if have_pandas:
    import pandas as pd

if have_pyarrow:
    import pyarrow as pa  # noqa: F401


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
    def test_apply_in_pandas_with_state_basic(self):
        df = self.spark.readStream.format("text").load("python/test_support/sql/streaming")

        for q in self.spark.streams.active:
            q.stop()
        self.assertTrue(df.isStreaming)

        output_type = StructType(
            [StructField("key", StringType()), StructField("countAsString", StringType())]
        )
        state_type = StructType([StructField("c", LongType())])

        def func(key, pdf_iter, state):
            assert isinstance(state, GroupState)

            total_len = 0
            for pdf in pdf_iter:
                total_len += len(pdf)

            state.update((total_len,))
            # Each distinct word occurs exactly once in the test data,
            # so the stored count is always 1 here.
            assert state.get[0] == 1
            yield pd.DataFrame({"key": [key[0]], "countAsString": [str(total_len)]})

        def check_results(batch_df, _):
            self.assertEqual(
                set(batch_df.collect()),
                {Row(key="hello", countAsString="1"), Row(key="this", countAsString="1")},
            )

        q = (
            df.groupBy(df["value"])
            .applyInPandasWithState(
                func, output_type, state_type, "Update", GroupStateTimeout.NoTimeout
            )
            .writeStream.queryName("this_query")
            .foreachBatch(check_results)
            .outputMode("update")
            .start()
        )

        self.assertEqual(q.name, "this_query")
        self.assertTrue(q.isActive)
        q.processAllAvailable()


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_grouped_map_with_state import *  # noqa: F401

    try:
        import xmlrunner  # type: ignore[import]

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
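As background for the test above: `applyInPandasWithState` invokes the user function once per group in each micro-batch, passing the grouping key as a tuple, an iterator of pandas DataFrames for that key, and a `GroupState` handle; the function yields DataFrames matching the declared output schema. A minimal sketch of a function that keeps a running count across micro-batches could look like this (`running_count` is a hypothetical name, not part of the change):

```python
import pandas as pd
from pyspark.sql.streaming.state import GroupState

def running_count(key, pdf_iter, state: GroupState):
    # Check state.exists before reading; `state.get` fails on empty state.
    prev = state.get[0] if state.exists else 0
    rows = sum(len(pdf) for pdf in pdf_iter)
    # The tuple passed to update() must line up with the declared state schema.
    state.update((prev + rows,))
    yield pd.DataFrame({"key": [key[0]], "countAsString": [str(prev + rows)]})
```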
sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

+import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters._
@@ -31,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, Pyth
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.SparkUserDefinedFunction
-import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType}

/**
* This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF,
@@ -190,7 +191,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
    throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
  }

-  private lazy val pandasFunc: Array[Byte] = if (shouldTestScalarPandasUDFs) {
+  private lazy val pandasFunc: Array[Byte] = if (shouldTestPandasUDFs) {
    var binaryPandasFunc: Array[Byte] = null
    withTempPath { path =>
      Process(
@@ -213,7 +214,7 @@
    throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
  }

-  private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestGroupedAggPandasUDFs) {
+  private lazy val pandasGroupedAggFunc: Array[Byte] = if (shouldTestPandasUDFs) {
    var binaryPandasFunc: Array[Byte] = null
    withTempPath { path =>
      Process(
@@ -235,6 +236,33 @@
    throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
  }

+  private def createPandasGroupedMapFuncWithState(pythonScript: String): Array[Byte] = {
+    if (shouldTestPandasUDFs) {
+      var binaryPandasFunc: Array[Byte] = null
+      withTempPath { codePath =>
+        Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
+        withTempPath { path =>
+          Process(
+            Seq(
+              pythonExec,
+              "-c",
+              "from pyspark.serializers import CloudPickleSerializer; " +
+                s"f = open('$path', 'wb');" +
+                s"exec(open('$codePath', 'r').read());" +
+                "f.write(CloudPickleSerializer().dumps((" +
+                "func, tpe)))"),
+            None,
+            "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+          binaryPandasFunc = Files.readAllBytes(path.toPath)
+        }
+      }
+      assert(binaryPandasFunc != null)
+      binaryPandasFunc
+    } else {
+      throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
+    }
+  }
+
  // Make sure this map stays mutable - this map gets updated later in Python runners.
  private val workerEnv = new java.util.HashMap[String, String]()
  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
@@ -251,11 +279,9 @@

  lazy val shouldTestPythonUDFs: Boolean = isPythonAvailable && isPySparkAvailable

-  lazy val shouldTestScalarPandasUDFs: Boolean =
+  lazy val shouldTestPandasUDFs: Boolean =
    isPythonAvailable && isPandasAvailable && isPyArrowAvailable

-  lazy val shouldTestGroupedAggPandasUDFs: Boolean = shouldTestScalarPandasUDFs

  /**
   * A base trait for various UDFs defined in this object.
   */
@@ -420,6 +446,41 @@ object IntegratedUDFTestUtils extends SQLHelper {
val prettyName: String = "Grouped Aggregate Pandas UDF"
}

+  /**
+   * Arbitrary stateful processing in Python is used for
+   * `DataFrame.groupBy.applyInPandasWithState`. It requires `pythonScript` to
+   * define `func` (a Python function) and `tpe` (the `StructType` for the state).
+   *
+   * Virtually equivalent to:
+   *
+   * {{{
+   *   # exec defines 'func' and 'tpe' (struct type for the state)
+   *   exec(pythonScript)
+   *
+   *   # ... are filled when this UDF is invoked, see also 'PythonFlatMapGroupsWithStateSuite'.
+   *   df.groupBy(...).applyInPandasWithState(func, ..., tpe, ..., ...)
+   * }}}
+   */
+  case class TestGroupedMapPandasUDFWithState(name: String, pythonScript: String) extends TestUDF {
+    private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
+      name = name,
+      func = SimplePythonFunction(
+        command = createPandasGroupedMapFuncWithState(pythonScript),
+        envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+        pythonIncludes = List.empty[String].asJava,
+        pythonExec = pythonExec,
+        pythonVer = pythonVer,
+        broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+        accumulator = null),
+      dataType = NullType, // This is not respected.
+      pythonEvalType = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+      udfDeterministic = true)
+
+    def apply(exprs: Column*): Column = udf(exprs: _*)
+
+    val prettyName: String = "Grouped Map Pandas UDF with State"
+  }

  /**
   * A Scala UDF that takes one column, casts into string, executes the
   * Scala native function, and casts back to the type of input column.
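For illustration, a `pythonScript` satisfying the `TestGroupedMapPandasUDFWithState` contract might look like the following sketch (hypothetical; the real scripts are supplied by callers such as 'PythonFlatMapGroupsWithStateSuite'):

```python
# Hypothetical script body: the Scala launcher exec()s this code and then
# pickles the resulting (func, tpe) pair with CloudPickleSerializer.
import pandas as pd
from pyspark.sql.types import LongType, StructField, StructType

def func(key, pdf_iter, state):
    # Accumulate a per-key row count in the group state.
    count = state.get[0] if state.exists else 0
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))
    yield pd.DataFrame({"key": [key[0]], "count": [count]})

# `tpe` is the StructType describing the state kept per group.
tpe = StructType([StructField("count", LongType())])
```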
sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -251,14 +251,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
          /* Do nothing */
        }
      case udfTestCase: UDFTest
-          if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestScalarPandasUDFs =>
+          if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && !shouldTestPandasUDFs =>
        ignore(s"${testCase.name} is skipped because pyspark," +
          s"pandas and/or pyarrow were not available in [$pythonExec].") {
          /* Do nothing */
        }
      case udfTestCase: UDFTest
          if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
-            !shouldTestGroupedAggPandasUDFs =>
+            !shouldTestPandasUDFs =>
        ignore(s"${testCase.name} is skipped because pyspark," +
          s"pandas and/or pyarrow were not available in [$pythonExec].") {
          /* Do nothing */
@@ -447,12 +447,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
          if udfTestCase.udf.isInstanceOf[TestPythonUDF] && shouldTestPythonUDFs =>
        s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}"
      case udfTestCase: UDFTest
-          if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestScalarPandasUDFs =>
+          if udfTestCase.udf.isInstanceOf[TestScalarPandasUDF] && shouldTestPandasUDFs =>
        s"${testCase.name}${System.lineSeparator()}" +
          s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
      case udfTestCase: UDFTest
          if udfTestCase.udf.isInstanceOf[TestGroupedAggPandasUDF] &&
-            shouldTestGroupedAggPandasUDFs =>
+            shouldTestPandasUDFs =>
        s"${testCase.name}${System.lineSeparator()}" +
          s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
      case _ =>
sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala
@@ -128,7 +128,7 @@ class QueryCompilationErrorsSuite

test("INVALID_PANDAS_UDF_PLACEMENT: Using aggregate function with grouped aggregate pandas UDF") {
import IntegratedUDFTestUtils._
assume(shouldTestGroupedAggPandasUDFs)
assume(shouldTestPandasUDFs)

val df = Seq(
(536361, "85123A", 2, 17850),
@@ -180,7 +180,7 @@ class QueryCompilationErrorsSuite

test("UNSUPPORTED_FEATURE: Using pandas UDF aggregate expression with pivot") {
import IntegratedUDFTestUtils._
assume(shouldTestGroupedAggPandasUDFs)
assume(shouldTestPandasUDFs)

val df = Seq(
(536361, "85123A", 2, 17850),
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDFSuite.scala
@@ -73,7 +73,7 @@ class PythonUDFSuite extends QueryTest with SharedSparkSession {
  }

  test("SPARK-39962: Global aggregation of Pandas UDF should respect the column order") {
-    assume(shouldTestGroupedAggPandasUDFs)
+    assume(shouldTestPandasUDFs)
    val df = Seq[(java.lang.Integer, java.lang.Integer)]((1, null)).toDF("a", "b")

    val pandasTestUDF = TestGroupedAggPandasUDF(name = "pandas_udf")