@@ -558,17 +558,40 @@ def check_datatype(datatype):
558558 _verify_type (PythonOnlyPoint (1.0 , 2.0 ), PythonOnlyUDT ())
559559 self .assertRaises (ValueError , lambda : _verify_type ([1.0 , 2.0 ], PythonOnlyUDT ()))
560560
def test_simple_udt_in_df(self):
    """Build a DataFrame with a top-level ``PythonOnlyUDT`` column and show it."""
    schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    # Ten rows; keys cycle through 0, 1, 2 and each value is the point (i, i).
    rows = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
    df = self.spark.createDataFrame(rows, schema=schema)
    df.show()
566567
def test_nested_udt_in_df(self):
    """Collect DataFrames whose ``PythonOnlyUDT`` values sit inside an array and a map."""
    # Case 1: the UDT is the element type of an ArrayType column.
    array_schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
    array_rows = [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)]
    df = self.spark.createDataFrame(array_rows, schema=array_schema)
    df.collect()

    # Case 2: the UDT is the value type of a MapType column keyed by longs.
    map_schema = StructType().add("key", LongType()).add(
        "val", MapType(LongType(), PythonOnlyUDT()))
    map_rows = [
        (i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))})
        for i in range(10)
    ]
    df = self.spark.createDataFrame(map_rows, schema=map_schema)
    df.collect()
581+
def test_complex_nested_udt_in_df(self):
    """Collect a UDT column after ``collect_list`` and after a UDF that re-nests it.

    The UDF's declared return type is ``ArrayType(df.schema)``: an array of
    structs shaped like the source rows, i.e. ``(key, PythonOnlyUDT)``.
    """
    from pyspark.sql.functions import udf

    point_schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
    rows = [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)]
    df = self.spark.createDataFrame(rows, schema=point_schema)
    df.collect()

    # Aggregate the UDT values of each key into a list.
    grouped = df.groupby("key").agg({"val": "collect_list"})
    grouped.collect()

    # Pair each key with the first collected point, wrapped in a one-element list.
    first_point = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
    grouped.select(first_point(*grouped)).collect()
572595
573596 def test_infer_schema_with_udt (self ):
574597 from pyspark .sql .tests import ExamplePoint , ExamplePointUDT
0 commit comments