@@ -296,6 +296,7 @@ def tearDown(self):
         # tear down test_bucketed_write state
         self.spark.sql("DROP TABLE IF EXISTS pyspark_bucket")
 
+    '''
     def test_row_should_be_read_only(self):
         row = Row(a=1, b=2)
         self.assertEqual(1, row.a)
@@ -1884,7 +1885,164 @@ def test_query_manager_await_termination(self):
         finally:
             q.stop()
             shutil.rmtree(tmpPath)
+    '''
 
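+    # Test helper that records every open/process/close callback from a foreach
+    # writer by dropping a small event file into a temp directory, so the
+    # driver-side assertions can read the events back with spark.read.json().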
+    class ForeachWriterTester:
+
+        def __init__(self, spark):
+            self.spark = spark
+            self.input_dir = tempfile.mkdtemp()
+            self.open_events_dir = tempfile.mkdtemp()
+            self.process_events_dir = tempfile.mkdtemp()
+            self.close_events_dir = tempfile.mkdtemp()
+
+        def write_open_event(self, partitionId, epochId):
+            self._write_event(
+                self.open_events_dir,
+                {'partition': partitionId, 'epoch': epochId})
+
+        def write_process_event(self, row):
+            self._write_event(self.process_events_dir, {'value': 'text'})
+
+        def write_close_event(self, error):
+            self._write_event(self.close_events_dir, {'error': str(error)})
+
+        def write_input_file(self):
+            self._write_event(self.input_dir, "text")
+
+        def open_events(self):
+            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')
+
+        def process_events(self):
+            return self._read_events(self.process_events_dir, 'value STRING')
+
+        def close_events(self):
+            return self._read_events(self.close_events_dir, 'error STRING')
+
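+        # Start a text-file stream over input_dir, sink it through
+        # writeStream.foreach(writer), then feed it num_files input files,
+        # waiting for each one to be fully processed before stopping.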
+        def run_streaming_query_on_writer(self, writer, num_files):
+            try:
+                sdf = self.spark.readStream.format('text').load(self.input_dir)
+                sq = sdf.writeStream.foreach(writer).start()
+                for i in range(num_files):
+                    self.write_input_file()
+                    sq.processAllAvailable()
+                sq.stop()
+            finally:
+                self.stop_all()
+
+        def _read_events(self, dir, json):
+            rows = self.spark.read.schema(json).json(dir).collect()
+            dicts = [row.asDict() for row in rows]
+            return dicts
+
+        def _write_event(self, dir, event):
+            import random
+            file = open(os.path.join(dir, str(random.randint(0, 100000))), 'w')
+            file.write("%s\n" % str(event))
+            file.close()
+
+        def stop_all(self):
+            for q in self.spark._wrapped.streams.active:
+                q.stop()
+
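+        # Keep only the event directories when pickling: the tester rides along
+        # inside the foreach writer that Spark serializes and ships to the
+        # executors, and the SparkSession itself cannot be serialized.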
+        def __getstate__(self):
+            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)
+
+        def __setstate__(self, state):
+            self.open_events_dir, self.process_events_dir, self.close_events_dir = state
+
+    def test_streaming_foreach_with_simple_function(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        def foreach_func(row):
+            tester.write_process_event(row)
+
+        tester.run_streaming_query_on_writer(foreach_func, 2)
+        self.assertEqual(len(tester.process_events()), 2)
+
+    def test_streaming_foreach_with_basic_open_process_close(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return True
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+        open_events = tester.open_events()
+        self.assertEqual(len(open_events), 2)
+        self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})
+
+        self.assertEqual(len(tester.process_events()), 2)
+
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 2)
+        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+    def test_streaming_foreach_with_open_returning_false(self):
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return False
+
+            def process(self, row):
+                tester.write_process_event(row)
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
+
+        self.assertEqual(len(tester.open_events()), 2)
+        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 2)
+        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})
+
+    def test_streaming_foreach_with_process_throwing_error(self):
+        from pyspark.sql.utils import StreamingQueryException
+
+        tester = self.ForeachWriterTester(self.spark)
+
+        class ForeachWriter:
+            def open(self, partitionId, epochId):
+                tester.write_open_event(partitionId, epochId)
+                return True
+
+            def process(self, row):
+                raise Exception("test error")
+
+            def close(self, error):
+                tester.write_close_event(error)
+
+        try:
+            sdf = self.spark.readStream.format('text').load(tester.input_dir)
+            sq = sdf.writeStream.foreach(ForeachWriter()).start()
+            tester.write_input_file()
+            sq.processAllAvailable()
+            self.fail("bad writer should fail the query")  # this is not expected
+        except StreamingQueryException as e:
+            # self.assertTrue("test error" in e.desc)  # this is expected
+            pass
+        finally:
+            tester.stop_all()
+
+        self.assertEqual(len(tester.open_events()), 1)
+        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
+        close_events = tester.close_events()
+        self.assertEqual(len(close_events), 1)
+        # self.assertTrue("test error" in close_events[0]['error'])
+
+    '''
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -5391,7 +5549,7 @@ def test_invalid_args(self):
                     AnalysisException,
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
-
+    '''
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner: