 else:
     import unittest
 
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType
+from pyspark.sql import SQLContext, HiveContext, IntegerType, Row, ArrayType, StructType, \
+    StructField, UserDefinedType, DoubleType
 from pyspark.tests import ReusedPySparkTestCase
 
 
@@ -285,6 +285,38 @@ def test_aggregator(self):
         self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
 
+    def test_save_and_load(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.save(tmpPath, "org.apache.spark.sql.json", "error")
+        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        from pyspark.sql import StructType, StructField, StringType
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
+        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+
+        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
+        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        df.save(dataSourceName="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+                noUse="this option will not be used in save.")
+        actual = self.sqlCtx.load(dataSourceName="org.apache.spark.sql.json", path=tmpPath,
+                                  noUse="this option will not be used in load.")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                              "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.sqlCtx.load(path=tmpPath)
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
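
The hunk above exercises the new DataFrame save/load path. As a quick orientation, here is a minimal sketch of the same calls outside the test harness, assuming an existing SparkContext `sc`; the path is illustrative:

    from pyspark.sql import SQLContext, Row

    sqlCtx = SQLContext(sc)
    df = sqlCtx.inferSchema(sc.parallelize([Row(key=i, value=str(i)) for i in range(10)]))

    # Write the DataFrame out as JSON; mode "error" fails if the path already exists.
    df.save("/tmp/example-json", "org.apache.spark.sql.json", "error")

    # Read it back through the same data source, optionally with an explicit schema.
    loaded = sqlCtx.load("/tmp/example-json", "org.apache.spark.sql.json")

    # With no data source name, load() falls back to spark.sql.sources.default.
    sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
    loaded_by_default = sqlCtx.load(path="/tmp/example-json")
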
@@ -294,6 +326,73 @@ def test_help_command(self):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sqlCtx = HiveContext(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = cls.sc.parallelize(cls.testData)
+        cls.df = cls.sqlCtx.inferSchema(rdd)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_save_and_load_table(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+                                                 "org.apache.spark.sql.json")
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+        from pyspark.sql import StructType, StructField, StringType
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.createExternalTable("externalJsonTable",
+                                                 dataSourceName="org.apache.spark.sql.json",
+                                                 schema=schema, path=tmpPath,
+                                                 noUse="this option will not be used")
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.select("value").collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE savedJsonTable")
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                              "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE savedJsonTable")
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
 
 if __name__ == "__main__":
     unittest.main()
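
For readers following the HiveContext test, a minimal sketch of the table-level API it exercises, again assuming an existing SparkContext `sc`; the table names and path are illustrative:

    from pyspark.sql import HiveContext, Row

    hiveCtx = HiveContext(sc)
    df = hiveCtx.inferSchema(sc.parallelize([Row(key=i, value=str(i)) for i in range(10)]))

    # Save as a metastore table whose data lives at an external JSON path.
    df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append",
                   path="/tmp/json-table")

    # Point a second, external table at the same files and query either one.
    hiveCtx.createExternalTable("externalJsonTable", "/tmp/json-table",
                                "org.apache.spark.sql.json")
    rows = hiveCtx.sql("SELECT * FROM externalJsonTable").collect()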