39 commits
5390743
Add more save APIs to DataFrame.
yhuai Feb 5, 2015
5ffc372
Add more load APIs to SQLContext.
yhuai Feb 5, 2015
43bae01
Remove createTable from HiveContext.
yhuai Feb 5, 2015
e6a0b77
Update test.
yhuai Feb 5, 2015
2a6213a
Update API names.
yhuai Feb 5, 2015
af9e9b3
DDL and write support API followup.
yhuai Feb 6, 2015
ed4e1b4
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 6, 2015
b1e9b1b
Format.
yhuai Feb 6, 2015
e702386
Apache header.
yhuai Feb 7, 2015
f2f33ef
Fix test.
yhuai Feb 8, 2015
6dfd386
Add java test.
yhuai Feb 8, 2015
7db95ff
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 8, 2015
cf5703d
Add checkAnswer to Java tests.
yhuai Feb 9, 2015
e04d908
Move import and add (Scala-specific) to scala APIs.
yhuai Feb 9, 2015
77d89dc
Update doc.
yhuai Feb 9, 2015
4679665
Remove duplicate rule.
yhuai Feb 9, 2015
99950a2
Use Java enum for SaveMode.
yhuai Feb 9, 2015
c2be775
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 9, 2015
9b6e570
Update doc.
yhuai Feb 9, 2015
9ff97d8
Add SaveMode to saveAsTable.
yhuai Feb 10, 2015
a10223d
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 10, 2015
c204967
Format
yhuai Feb 10, 2015
2bf44ef
Python APIs.
yhuai Feb 10, 2015
98e7cdb
Python style.
yhuai Feb 10, 2015
0832ce4
Fix test.
yhuai Feb 10, 2015
3abc215
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 10, 2015
4c76d78
Simplify APIs.
yhuai Feb 10, 2015
d91ecb8
Fix test.
yhuai Feb 10, 2015
22cfa70
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 10, 2015
d1c12d3
No need to delete the duplicate rule since it has been removed in mas…
yhuai Feb 10, 2015
cbc717f
Rename dataSourceName to source.
yhuai Feb 10, 2015
92b6659
Python doc and other minor updates.
yhuai Feb 10, 2015
609129c
Doc format.
yhuai Feb 10, 2015
ae4649e
Fix Python test.
yhuai Feb 10, 2015
537e28f
Correctly clean up temp data.
yhuai Feb 10, 2015
2091fcd
Merge remote-tracking branch 'upstream/master' into writeSupportFollowup
yhuai Feb 10, 2015
2306f93
Style.
yhuai Feb 10, 2015
225ff71
Use Scala TestHiveContext to initialize the Python HiveContext in Pyt…
yhuai Feb 10, 2015
f3a96f7
davies's comments.
yhuai Feb 10, 2015
68 changes: 68 additions & 0 deletions python/pyspark/sql/context.py
@@ -21,6 +21,7 @@
from itertools import imap

from py4j.protocol import Py4JError
from py4j.java_collections import MapConverter

from pyspark.rdd import _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -87,6 +88,18 @@ def _ssql_ctx(self):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext

def setConf(self, key, value):
"""Sets the given Spark SQL configuration property.
"""
self._ssql_ctx.setConf(key, value)

def getConf(self, key, defaultValue):
"""Returns the value of Spark SQL configuration property for the given key.

If the key is not set, returns defaultValue.
"""
return self._ssql_ctx.getConf(key, defaultValue)
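A minimal usage sketch for the new configuration accessors (sqlCtx is an assumed SQLContext instance; the key shown is only an illustration):

sqlCtx.setConf("spark.sql.shuffle.partitions", "10")
sqlCtx.getConf("spark.sql.shuffle.partitions", "200")  # returns "10" once set; "200" if the key were unset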

def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.

@@ -455,6 +468,61 @@ def func(iterator):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)

def load(self, path=None, source=None, schema=None, **options):
"""Returns the dataset in a data source as a DataFrame.

The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
spark.sql.sources.default will be used.

Optionally, a schema can be provided as the schema of the returned DataFrame.
"""
if path is not None:
options["path"] = path
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
joptions = MapConverter().convert(options,
self._sc._gateway._gateway_client)
if schema is None:
df = self._ssql_ctx.load(source, joptions)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
df = self._ssql_ctx.load(source, scala_datatype, joptions)
return DataFrame(df, self)
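A usage sketch for the new load API, assuming a SQLContext named sqlCtx and JSON data at a hypothetical path /tmp/people.json:

from pyspark.sql.types import StructType, StructField, StringType

# Explicit data source; extra keyword arguments become data source options.
df1 = sqlCtx.load("/tmp/people.json", source="org.apache.spark.sql.json")

# Explicit schema; the provided StructType becomes the schema of the returned DataFrame.
schema = StructType([StructField("name", StringType(), True)])
df2 = sqlCtx.load(path="/tmp/people.json", source="org.apache.spark.sql.json", schema=schema)

# No source given: spark.sql.sources.default is used (Parquet unless reconfigured).
df3 = sqlCtx.load("/tmp/people.parquet")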

def createExternalTable(self, tableName, path=None, source=None,
schema=None, **options):
"""Creates an external table based on the dataset in a data source.

It returns the DataFrame associated with the external table.

The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
spark.sql.sources.default will be used.

Optionally, a schema can be provided as the schema of the returned DataFrame and
created external table.
"""
if path is not None:
options["path"] = path
if source is None:
source = self.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
joptions = MapConverter().convert(options,
self._sc._gateway._gateway_client)
if schema is None:
df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
else:
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
joptions)
return DataFrame(df, self)
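A usage sketch for createExternalTable, assuming sqlCtx is a HiveContext (as in the Python tests in this patch) and the path is hypothetical:

# Register JSON data as an external table and get back the associated DataFrame.
people = sqlCtx.createExternalTable("people", path="/tmp/people.json", source="org.apache.spark.sql.json")
rows = sqlCtx.sql("SELECT * FROM people").collect()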

def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.

72 changes: 69 additions & 3 deletions python/pyspark/sql/dataframe.py
@@ -146,9 +146,75 @@ def insertInto(self, tableName, overwrite=False):
"""
self._jdf.insertInto(tableName, overwrite)

def saveAsTable(self, tableName):
"""Creates a new table with the contents of this DataFrame."""
self._jdf.saveAsTable(tableName)
def _java_save_mode(self, mode):
"""Returns the Java save mode based on the Python save mode represented by a string.
"""
jSaveMode = self._sc._jvm.org.apache.spark.sql.sources.SaveMode
jmode = jSaveMode.ErrorIfExists
mode = mode.lower()
if mode == "append":
jmode = jSaveMode.Append
elif mode == "overwrite":
jmode = jSaveMode.Overwrite
elif mode == "ignore":
jmode = jSaveMode.Ignore
elif mode == "error":
pass
else:
raise ValueError(
"Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save mode.")
return jmode

def saveAsTable(self, tableName, source=None, mode="append", **options):
Review comments on this line:

Contributor: I think tableName could be just name.

Author: I am not sure if we can do it. tableName has been used in lots of places.

Contributor: I saw it's only used in insertInto() and saveAsTable(), and registerAsTable() uses name. It's not a keyword argument, so it's safe to change the name.

"""Saves the contents of the DataFrame to a data source as a table.

The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
spark.sql.sources.default will be used.

Additionally, mode is used to specify the behavior of the saveAsTable operation when the
table already exists in the data source. There are four modes:

* append: Contents of this DataFrame are expected to be appended to the existing table.
* overwrite: Data in the existing table is expected to be overwritten by the contents of \
this DataFrame.
* error: An exception is expected to be thrown.
* ignore: The save operation is expected to not save the contents of the DataFrame and \
to not change the existing table.
"""
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
joptions = MapConverter().convert(options,
self.sql_ctx._sc._gateway._gateway_client)
self._jdf.saveAsTable(tableName, source, jmode, joptions)
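A usage sketch mirroring the Python tests in this patch (df is an assumed DataFrame, the table names and paths are hypothetical, and a HiveContext is assumed for the table catalog):

# Save as a table backed by the JSON data source; append if the table already exists.
df.saveAsTable("savedJsonTable", source="org.apache.spark.sql.json", mode="append", path="/tmp/savedJsonTable")

# Rely on spark.sql.sources.default for the format and overwrite any existing table.
df.saveAsTable("defaultSourceTable", mode="overwrite", path="/tmp/defaultSourceTable")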

def save(self, path=None, source=None, mode="append", **options):
"""Saves the contents of the DataFrame to a data source.

The data source is specified by the `source` and a set of `options`.
If `source` is not specified, the default data source configured by
spark.sql.sources.default will be used.

Additionally, mode is used to specify the behavior of the save operation when
data already exists in the data source. There are four modes:

* append: Contents of this DataFrame are expected to be appended to existing data.
* overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
* error: An exception is expected to be thrown.
* ignore: The save operation is expected to not save the contents of the DataFrame and \
to not change the existing data.
"""
if path is not None:
options["path"] = path
if source is None:
source = self.sql_ctx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
jmode = self._java_save_mode(mode)
joptions = MapConverter().convert(options,
self._sc._gateway._gateway_client)
self._jdf.save(source, jmode, joptions)
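A usage sketch for save (df is an assumed DataFrame; the output paths are hypothetical):

# Write the DataFrame as JSON, failing if something already exists at the path.
df.save("/tmp/people_json", source="org.apache.spark.sql.json", mode="error")

# Write with the default data source (spark.sql.sources.default, Parquet unless reconfigured),
# replacing any existing output.
df.save(path="/tmp/people_default", mode="overwrite")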

def schema(self):
"""Returns the schema of this DataFrame (represented by
107 changes: 104 additions & 3 deletions python/pyspark/sql/tests.py
@@ -34,10 +34,9 @@
else:
import unittest


from pyspark.sql import SQLContext, Column
from pyspark.sql import SQLContext, HiveContext, Column
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType, LongType
UserDefinedType, DoubleType, LongType, StringType
from pyspark.tests import ReusedPySparkTestCase


@@ -286,6 +285,37 @@ def test_aggregator(self):
self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])

def test_save_and_load(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.save(tmpPath, "org.apache.spark.sql.json", "error")
actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

schema = StructType([StructField("value", StringType(), True)])
actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))

df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
noUse="this options will not be used in save.")
actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
noUse="this options will not be used in load.")
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))

defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
actual = self.sqlCtx.load(path=tmpPath)
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

shutil.rmtree(tmpPath)

def test_help_command(self):
# Regression test for SPARK-5464
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
@@ -296,5 +326,76 @@ def test_help_command(self):
pydoc.render_doc(df.take(1))


class HiveContextSQLTests(ReusedPySparkTestCase):

@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
os.unlink(cls.tempdir.name)
print "type", type(cls.sc)
print "type", type(cls.sc._jsc)
_scala_HiveContext =\
cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
cls.df = cls.sqlCtx.inferSchema(rdd)

@classmethod
def tearDownClass(cls):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)

def test_save_and_load_table(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
"org.apache.spark.sql.json")
self.assertTrue(
sorted(df.collect()) ==
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
self.assertTrue(
sorted(df.collect()) ==
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
self.sqlCtx.sql("DROP TABLE externalJsonTable")

df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
schema = StructType([StructField("value", StringType(), True)])
actual = self.sqlCtx.createExternalTable("externalJsonTable",
source="org.apache.spark.sql.json",
schema=schema, path=tmpPath,
noUse="this options will not be used")
self.assertTrue(
sorted(df.collect()) ==
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
self.assertTrue(
sorted(df.select("value").collect()) ==
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
self.sqlCtx.sql("DROP TABLE savedJsonTable")
self.sqlCtx.sql("DROP TABLE externalJsonTable")

defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
"org.apache.spark.sql.parquet")
self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
self.assertTrue(
sorted(df.collect()) ==
sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
self.assertTrue(
sorted(df.collect()) ==
sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
self.sqlCtx.sql("DROP TABLE savedJsonTable")
self.sqlCtx.sql("DROP TABLE externalJsonTable")
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

shutil.rmtree(tmpPath)

if __name__ == "__main__":
unittest.main()
45 changes: 45 additions & 0 deletions sql/core/src/main/java/org/apache/spark/sql/sources/SaveMode.java
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.sources;

/**
* SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
*/
public enum SaveMode {
/**
* Append mode means that when saving a DataFrame to a data source, if data/table already exists,
* contents of the DataFrame are expected to be appended to existing data.
*/
Append,
/**
* Overwrite mode means that when saving a DataFrame to a data source,
* if data/table already exists, existing data is expected to be overwritten by the contents of
* the DataFrame.
*/
Overwrite,
/**
* ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
* an exception is expected to be thrown.
*/
ErrorIfExists,
/**
* Ignore mode means that when saving a DataFrame to a data source, if data already exists,
* the save operation is expected to not save the contents of the DataFrame and to not
* change the existing data.
*/
Ignore
}
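For reference, the Python DataFrame.save and saveAsTable methods in this patch accept mode as a string and convert it to one of these enum values via _java_save_mode; a sketch of that correspondence (an illustrative mapping under the semantics documented above, not code from the patch):

# Illustrative only; the actual conversion goes through py4j in DataFrame._java_save_mode.
MODE_STRING_TO_SAVE_MODE = {
    "append": "Append",          # append the DataFrame's contents to existing data
    "overwrite": "Overwrite",    # replace existing data with the DataFrame's contents
    "error": "ErrorIfExists",    # throw an exception if data already exists
    "ignore": "Ignore",          # leave existing data untouched and skip the save
}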