Skip to content

Commit d603cc2

Browse files
committed
Create new unit tests.
1 parent fc9c106 commit d603cc2

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

python/pyspark/sql/tests.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -558,17 +558,40 @@ def check_datatype(datatype):
558558
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
559559
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
560560

561+
def test_simple_udt_in_df(self):
561562
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
562563
df = self.spark.createDataFrame(
563564
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
564565
schema=schema)
565566
df.show()
566567

568+
def test_nested_udt_in_df(self):
567569
schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
568570
df = self.spark.createDataFrame(
569-
[(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
571+
[(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
572+
schema=schema)
573+
df.collect()
574+
575+
schema = StructType().add("key", LongType()).add("val",
576+
MapType(LongType(), PythonOnlyUDT()))
577+
df = self.spark.createDataFrame(
578+
[(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
579+
schema=schema)
580+
df.collect()
581+
582+
def test_complex_nested_udt_in_df(self):
583+
from pyspark.sql.functions import udf
584+
585+
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
586+
df = self.spark.createDataFrame(
587+
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
570588
schema=schema)
571-
df.show()
589+
df.collect()
590+
591+
gd = df.groupby("key").agg({"val": "collect_list"})
592+
gd.collect()
593+
udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
594+
gd.select(udf(*gd)).collect()
572595

573596
def test_infer_schema_with_udt(self):
574597
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

0 commit comments

Comments
 (0)