
Commit 2bf44ef

Python APIs.

Adds load(), createExternalTable(), save(), and saveAsTable() to the PySpark
SQL API, with tests covering save/load round trips and the
spark.sql.sources.default setting.

1 parent c204967 · commit 2bf44ef

2 files changed: +194 -5 lines changed


python/pyspark/sql.py

Lines changed: 93 additions & 3 deletions
@@ -1622,6 +1622,48 @@ def func(iterator):
         df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
         return DataFrame(df, self)
 
+    def load(self, path=None, dataSourceName=None, schema=None, **options):
+        """Returns the dataset specified by the data source and a set of options
+        as a DataFrame. An optional schema can be applied as the schema of the returned
+        DataFrame. If dataSourceName is not provided, the default data source configured
+        by spark.sql.sources.default will be used.
+        """
+        if path is not None:
+            options["path"] = path
+        if dataSourceName is None:
+            dataSourceName = self._ssql_ctx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        joptions = MapConverter().convert(options,
+                                          self._sc._gateway._gateway_client)
+        if schema is None:
+            df = self._ssql_ctx.load(dataSourceName, joptions)
+        else:
+            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+            df = self._ssql_ctx.load(dataSourceName, scala_datatype, joptions)
+        return DataFrame(df, self)
+
+    def createExternalTable(self, tableName, path=None, dataSourceName=None,
+                            schema=None, **options):
+        """Creates an external table based on the given data source and a set of options,
+        and returns the corresponding DataFrame.
+        If dataSourceName is not provided, the default data source configured
+        by spark.sql.sources.default will be used.
+        """
+        if path is not None:
+            options["path"] = path
+        if dataSourceName is None:
+            dataSourceName = self._ssql_ctx.getConf("spark.sql.sources.default",
+                                                    "org.apache.spark.sql.parquet")
+        joptions = MapConverter().convert(options,
+                                          self._sc._gateway._gateway_client)
+        if schema is None:
+            df = self._ssql_ctx.createExternalTable(tableName, dataSourceName, joptions)
+        else:
+            scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+            df = self._ssql_ctx.createExternalTable(tableName, dataSourceName, scala_datatype,
+                                                    joptions)
+        return DataFrame(df, self)
+
     def sql(self, sqlQuery):
         """Return a L{DataFrame} representing the result of the given query.
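For orientation, here is a minimal usage sketch of the two reader-side methods added
in the hunk above, as one might call them from a PySpark shell. The JSON path, the
schema, and the pre-existing SparkContext `sc` are illustrative assumptions, not part
of this commit:

    from pyspark.sql import SQLContext, StructType, StructField, StringType

    sqlCtx = SQLContext(sc)  # assumes an already-running SparkContext `sc`

    # Explicit data source name; any extra keyword arguments are passed
    # through to the data source as options.
    df = sqlCtx.load(path="/tmp/people.json",
                     dataSourceName="org.apache.spark.sql.json")

    # The same load, but applying a user-supplied schema instead of the
    # inferred one.
    schema = StructType([StructField("name", StringType(), True)])
    people = sqlCtx.load(path="/tmp/people.json",
                         dataSourceName="org.apache.spark.sql.json",
                         schema=schema)

    # createExternalTable follows the same pattern but also registers the
    # data under a table name in the metastore (the new tests below run it
    # against a HiveContext).
    # tbl = sqlCtx.createExternalTable("people", path="/tmp/people.json",
    #                                  dataSourceName="org.apache.spark.sql.json")

Omitting dataSourceName in either call falls back to spark.sql.sources.default,
which defaults to org.apache.spark.sql.parquet.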
@@ -1889,9 +1931,57 @@ def insertInto(self, tableName, overwrite=False):
         """
         self._jdf.insertInto(tableName, overwrite)
 
-    def saveAsTable(self, tableName):
-        """Creates a new table with the contents of this DataFrame."""
-        self._jdf.saveAsTable(tableName)
+    def saveAsTable(self, tableName, dataSourceName=None, mode="append", **options):
+        """Creates a new table with the contents of this DataFrame based on the given data source
+        and a set of options. If a data source is not provided, the default data source configured
+        by spark.sql.sources.default will be used.
+        """
+        if dataSourceName is None:
+            dataSourceName = self.sql_ctx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                            "org.apache.spark.sql.parquet")
+        jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.ErrorIfExists
+        mode = mode.lower()
+        if mode == "append":
+            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Append
+        elif mode == "overwrite":
+            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Overwrite
+        elif mode == "ignore":
+            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Ignore
+        elif mode == "error":
+            pass
+        else:
+            raise ValueError(
+                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
+        joptions = MapConverter().convert(options,
+                                          self.sql_ctx._sc._gateway._gateway_client)
+        self._jdf.saveAsTable(tableName, dataSourceName, jmode, joptions)
+
+    def save(self, path=None, dataSourceName=None, mode="append", **options):
+        """Saves the contents of the DataFrame to a data source, based on the given
+        data source name, the given save mode, and a set of options. If a data source
+        name is not provided, the default data source configured by
+        spark.sql.sources.default will be used.
+        """
+        if path is not None:
+            options["path"] = path
+        if dataSourceName is None:
+            dataSourceName = self.sql_ctx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                            "org.apache.spark.sql.parquet")
+        jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.ErrorIfExists
+        mode = mode.lower()
+        if mode == "append":
+            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Append
+        elif mode == "overwrite":
+            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Overwrite
+        elif mode == "ignore":
+            jmode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode.Ignore
+        elif mode == "error":
+            pass
+        else:
+            raise ValueError(
+                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
+        joptions = MapConverter().convert(options,
+                                          self._sc._gateway._gateway_client)
+        self._jdf.save(dataSourceName, jmode, joptions)
 
     def schema(self):
         """Returns the schema of this DataFrame (represented by

python/pyspark/sql_tests.py

Lines changed: 101 additions & 2 deletions
@@ -34,8 +34,8 @@
 else:
     import unittest
 
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType
+from pyspark.sql import SQLContext, HiveContext, IntegerType, Row, ArrayType, StructType, \
+    StructField, UserDefinedType, DoubleType
 from pyspark.tests import ReusedPySparkTestCase
 
@@ -285,6 +285,38 @@ def test_aggregator(self):
         self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
 
+    def test_save_and_load(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.save(tmpPath, "org.apache.spark.sql.json", "error")
+        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        from pyspark.sql import StructType, StructField, StringType
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
+        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+
+        df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
+        actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        df.save(dataSourceName="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+                noUse="this option will not be used in save.")
+        actual = self.sqlCtx.load(dataSourceName="org.apache.spark.sql.json", path=tmpPath,
+                                  noUse="this option will not be used in load.")
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                              "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        actual = self.sqlCtx.load(path=tmpPath)
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
+
     def test_help_command(self):
         # Regression test for SPARK-5464
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -294,6 +326,73 @@ def test_help_command(self):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+        cls.sqlCtx = HiveContext(cls.sc)
+        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+        rdd = cls.sc.parallelize(cls.testData)
+        cls.df = cls.sqlCtx.inferSchema(rdd)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+    def test_save_and_load_table(self):
+        df = self.df
+        tmpPath = tempfile.mkdtemp()
+        shutil.rmtree(tmpPath)
+        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+                                                 "org.apache.spark.sql.json")
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+        df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+        from pyspark.sql import StructType, StructField, StringType
+        schema = StructType([StructField("value", StringType(), True)])
+        actual = self.sqlCtx.createExternalTable("externalJsonTable",
+                                                 dataSourceName="org.apache.spark.sql.json",
+                                                 schema=schema, path=tmpPath,
+                                                 noUse="this option will not be used")
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.select("value").collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE savedJsonTable")
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+        defaultDataSourceName = self.sqlCtx._ssql_ctx.getConf("spark.sql.sources.default",
+                                                              "org.apache.spark.sql.parquet")
+        self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+        df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+        actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+        self.assertTrue(
+            sorted(df.collect()) ==
+            sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+        self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+        self.sqlCtx.sql("DROP TABLE savedJsonTable")
+        self.sqlCtx.sql("DROP TABLE externalJsonTable")
+        self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+        shutil.rmtree(tmpPath)
 
 if __name__ == "__main__":
     unittest.main()
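One design note on the two writer methods in sql.py: the mode-string-to-SaveMode
mapping is spelled out twice, once in saveAsTable and once in save. A hypothetical
module-level helper (a sketch, not part of this commit) that both could share:

    def _java_save_mode(sc, mode):
        """Map a Python save-mode string onto the JVM SaveMode enum (sketch only)."""
        jSaveMode = sc._jvm.org.apache.spark.sql.sources.SaveMode
        modes = {
            "append": jSaveMode.Append,
            "overwrite": jSaveMode.Overwrite,
            "ignore": jSaveMode.Ignore,
            "error": jSaveMode.ErrorIfExists,
        }
        if mode.lower() not in modes:
            raise ValueError(
                "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
        return modes[mode.lower()]

Each caller would then reduce to jmode = _java_save_mode(self._sc, mode).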
