119 changes: 20 additions & 99 deletions python/pyspark/sql/readwriter.py
@@ -44,84 +44,20 @@ def to_str(value):
         return str(value)
 
 
-class ReaderUtils(object):
+class OptionUtils(object):
 
-    def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
-                       allowComments, allowUnquotedFieldNames, allowSingleQuotes,
-                       allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
-                       mode, columnNameOfCorruptRecord):
+    def _set_opts(self, schema=None, **options):
         """
-        Set options based on the Json optional parameters
+        Set named options (filter out those the value is None)
         """
         if schema is not None:
             self.schema(schema)
-        if primitivesAsString is not None:
-            self.option("primitivesAsString", primitivesAsString)
-        if prefersDecimal is not None:
-            self.option("prefersDecimal", prefersDecimal)
-        if allowComments is not None:
-            self.option("allowComments", allowComments)
-        if allowUnquotedFieldNames is not None:
-            self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
-        if allowSingleQuotes is not None:
-            self.option("allowSingleQuotes", allowSingleQuotes)
-        if allowNumericLeadingZero is not None:
-            self.option("allowNumericLeadingZero", allowNumericLeadingZero)
-        if allowBackslashEscapingAnyCharacter is not None:
-            self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
-        if mode is not None:
-            self.option("mode", mode)
-        if columnNameOfCorruptRecord is not None:
-            self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
-
-    def _set_csv_opts(self, schema, sep, encoding, quote, escape,
-                      comment, header, inferSchema, ignoreLeadingWhiteSpace,
-                      ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
-                      dateFormat, maxColumns, maxCharsPerColumn, maxMalformedLogPerPartition, mode):
-        """
-        Set options based on the CSV optional parameters
-        """
-        if schema is not None:
-            self.schema(schema)
-        if sep is not None:
-            self.option("sep", sep)
-        if encoding is not None:
-            self.option("encoding", encoding)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if comment is not None:
-            self.option("comment", comment)
-        if header is not None:
-            self.option("header", header)
-        if inferSchema is not None:
-            self.option("inferSchema", inferSchema)
-        if ignoreLeadingWhiteSpace is not None:
-            self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
-        if ignoreTrailingWhiteSpace is not None:
-            self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if nanValue is not None:
-            self.option("nanValue", nanValue)
-        if positiveInf is not None:
-            self.option("positiveInf", positiveInf)
-        if negativeInf is not None:
-            self.option("negativeInf", negativeInf)
-        if dateFormat is not None:
-            self.option("dateFormat", dateFormat)
-        if maxColumns is not None:
-            self.option("maxColumns", maxColumns)
-        if maxCharsPerColumn is not None:
-            self.option("maxCharsPerColumn", maxCharsPerColumn)
-        if maxMalformedLogPerPartition is not None:
-            self.option("maxMalformedLogPerPartition", maxMalformedLogPerPartition)
-        if mode is not None:
-            self.option("mode", mode)
-
-
-class DataFrameReader(ReaderUtils):
+        for k, v in options.items():
+            if v is not None:
+                self.option(k, v)
+
+
+class DataFrameReader(OptionUtils):
     """
     Interface used to load a :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.read`
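For intuition, here is a minimal standalone sketch of the pattern that OptionUtils._set_opts implements: apply the schema when one is given, then forward every keyword argument whose value is not None to option(), silently dropping the rest. The FakeReader class below is purely illustrative and not part of the patch.

    class FakeReader(object):
        """Illustrative stand-in for a reader/writer with schema() and option()."""

        def __init__(self):
            self.opts = {}
            self.schema_value = None

        def schema(self, s):
            self.schema_value = s

        def option(self, k, v):
            self.opts[k] = v

        def _set_opts(self, schema=None, **options):
            # Same logic as OptionUtils._set_opts: schema first,
            # then only the non-None keyword options.
            if schema is not None:
                self.schema(schema)
            for k, v in options.items():
                if v is not None:
                    self.option(k, v)

    reader = FakeReader()
    reader._set_opts(header=True, sep=None, nullValue="NA")
    assert reader.opts == {"header": True, "nullValue": "NA"}  # sep was dropped
    assert reader.schema_value is None                         # no schema given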
@@ -270,7 +206,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         [('age', 'bigint'), ('name', 'string')]
 
         """
-        self._set_json_opts(
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -413,7 +349,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
         >>> df.dtypes
         [('_c0', 'string'), ('_c1', 'string')]
         """
-        self._set_csv_opts(
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -484,7 +420,7 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None,
         return self._df(self._jreader.jdbc(url, table, jprop))
 
 
-class DataFrameWriter(object):
+class DataFrameWriter(OptionUtils):
     """
     Interface used to write a :class:`DataFrame` to external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
@@ -649,8 +585,7 @@ def json(self, path, mode=None, compression=None):
         >>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.json(path)
 
     @since(1.4)
@@ -676,8 +611,7 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None):
         self.mode(mode)
         if partitionBy is not None:
             self.partitionBy(partitionBy)
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.parquet(path)
 
     @since(1.6)
@@ -692,8 +626,7 @@ def text(self, path, compression=None):
         The DataFrame must have only one column that is of string type.
         Each row becomes a new line in the output file.
         """
-        if compression is not None:
-            self.option("compression", compression)
+        self._set_opts(compression=compression)
         self._jwrite.text(path)
 
     @since(2.0)
@@ -731,20 +664,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
         >>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
         """
         self.mode(mode)
-        if compression is not None:
-            self.option("compression", compression)
-        if sep is not None:
-            self.option("sep", sep)
-        if quote is not None:
-            self.option("quote", quote)
-        if escape is not None:
-            self.option("escape", escape)
-        if header is not None:
-            self.option("header", header)
-        if nullValue is not None:
-            self.option("nullValue", nullValue)
-        if escapeQuotes is not None:
-            self.option("escapeQuotes", nullValue)
+        self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
+                       nullValue=nullValue, escapeQuotes=escapeQuotes)
         self._jwrite.csv(path)
 
     @since(1.5)
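Beyond removing boilerplate, the consolidated call also corrects the removed branch above, where escapeQuotes was mistakenly set from nullValue. A hedged usage sketch of the new path follows; it assumes a live SparkSession bound to the name spark, and the data and output path are purely illustrative.

    import os
    import tempfile

    out = os.path.join(tempfile.mkdtemp(), 'data')
    df = spark.createDataFrame([('a', None), ('b', 'x')], ['k', 'v'])

    # Each keyword below flows through DataFrameWriter._set_opts; keywords
    # left at None (e.g. sep, quote, escape) are dropped rather than set.
    df.write.csv(out, header=True, nullValue='NULL', escapeQuotes=True)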
@@ -803,7 +724,7 @@ def jdbc(self, url, table, mode=None, properties=None):
         self._jwrite.mode(mode).jdbc(url, table, jprop)
 
 
-class DataStreamReader(ReaderUtils):
+class DataStreamReader(OptionUtils):
     """
     Interface used to load a streaming :class:`DataFrame` from external storage systems
     (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
@@ -965,7 +886,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
         >>> json_sdf.schema == sdf_schema
         True
         """
-        self._set_json_opts(
+        self._set_opts(
             schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
             allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
             allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -1095,7 +1016,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
         >>> csv_sdf.schema == sdf_schema
         True
         """
-        self._set_csv_opts(
+        self._set_opts(
             schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
             header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
             ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,