def test_udtf_nullable_check(self):
    """Exercise UDTF output against the nullability declared in the return type.

    Every case is a triple of:

    * a one-column (``"value"``) return schema,
    * the single row that ``eval`` yields, and
    * the expectation handed to ``_check_result_or_exception``: either the
      list of rows to get back or the string ``"PySparkRuntimeError"``.

    Each case runs in its own ``subTest`` so one failing combination does
    not hide the others.
    """

    def column(data_type):
        # The return schema used by every case: one field named "value".
        return StructType([StructField("value", data_type)])

    def required_struct(data_type):
        # A struct whose only field, "value", is declared non-nullable.
        return StructType([StructField("value", data_type, nullable=False)])

    failure = "PySparkRuntimeError"

    cases = [
        # --- arrays: None element vs. containsNull ---
        (column(ArrayType(IntegerType(), containsNull=False)), ([None],), failure),
        (
            column(ArrayType(IntegerType(), containsNull=True)),
            ([None],),
            [Row(value=[None])],
        ),
        # --- maps: None value vs. valueContainsNull ---
        (
            column(MapType(StringType(), IntegerType(), valueContainsNull=False)),
            ({"a": None},),
            failure,
        ),
        (
            column(MapType(StringType(), IntegerType(), valueContainsNull=True)),
            ({"a": None},),
            [Row(value={"a": None})],
        ),
        # --- maps: a None key is expected to fail with either flag value ---
        (
            column(MapType(StringType(), IntegerType(), valueContainsNull=True)),
            ({None: 1},),
            failure,
        ),
        (
            column(MapType(StringType(), IntegerType(), valueContainsNull=False)),
            ({None: 1},),
            failure,
        ),
        # --- nested: non-nullable array elements inside a map value ---
        (
            column(
                MapType(
                    StringType(),
                    ArrayType(IntegerType(), containsNull=False),
                    valueContainsNull=False,
                )
            ),
            ({"s": [None]},),
            failure,
        ),
        # --- nested: struct map key whose non-nullable field is None ---
        (
            column(
                MapType(
                    required_struct(StringType()),
                    IntegerType(),
                    valueContainsNull=False,
                )
            ),
            ({(None,): 1},),
            failure,
        ),
        (
            column(
                MapType(
                    required_struct(StringType()),
                    IntegerType(),
                    valueContainsNull=True,
                )
            ),
            ({(None,): 1},),
            failure,
        ),
        # --- nested: struct column whose non-nullable field is None ---
        (column(required_struct(StringType())), ((None,),), failure),
        # --- nested: non-nullable array elements inside a struct field ---
        (
            column(required_struct(ArrayType(StringType(), containsNull=False))),
            (([None],),),
            failure,
        ),
    ]

    for ret_type, value, expected in cases:

        class TestUDTF:
            def eval(self):
                # `value` is the current case's row; the class is used
                # within this same iteration, before the name is rebound.
                yield value

        with self.subTest(ret_type=ret_type, value=value):
            self._check_result_or_exception(TestUDTF, ret_type, expected)