diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 3efe2adb6e2a..98b2cd996840 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -859,7 +859,7 @@ def text(self, path, compression=None, lineSep=None):
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
- charToEscapeQuoteEscaping=None):
+ charToEscapeQuoteEscaping=None, encoding=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -909,6 +909,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
the quote character. If None is set, the default value is
the escape character when escape and quote characters are
different, ``\0`` otherwise.
+ :param encoding: sets the encoding (charset) of the saved CSV files. If None is set,
+ the default UTF-8 charset will be used.
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
@@ -918,7 +920,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
dateFormat=dateFormat, timestampFormat=timestampFormat,
ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
- charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+ charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
+ encoding=encoding)
self._jwrite.csv(path)
@since(1.5)
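For reviewers who want to exercise the new option end to end, a minimal usage sketch with the Scala DataFrameWriter API (the Python signature above is a thin wrapper over the same option; the session setup and output path here are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical end-to-end use of the new "encoding" write option.
// The session setup and the /tmp output path are illustrative only.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("csv-encoding-demo")
  .getOrCreate()
import spark.implicits._

Seq("µß áâä ÁÂÄ").toDF("value")
  .write
  .option("encoding", "iso-8859-1") // charset of the saved CSV files
  .csv("/tmp/csv-iso-8859-1")
```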
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 90bea2d676e2..b9fa43f1f9fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -629,6 +629,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* enclosed in quotes. Default is to only escape values containing a quote character.
*
* `header` (default `false`): writes the names of columns as the first line.
* `nullValue` (default empty string): sets the string representation of a null value.
+ * `encoding` (default not set): specifies the encoding (charset) of the saved CSV
+ * files. If it is not set, the UTF-8 charset will be used.
* `compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`).
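The new writer code reads the charset from `CSVOptions.charset`, which this diff does not touch. For context, a hedged sketch of how that resolution behaves — `encoding` first, its alias `charset` next, UTF-8 as the documented fallback (the helper name and the plain options map are stand-ins for the real CSVOptions plumbing):

```scala
import java.nio.charset.StandardCharsets

// Sketch of the charset resolution the writer relies on: the "encoding"
// option wins, its alias "charset" is checked next, and UTF-8 is the
// documented fallback. `options` stands in for CSVOptions' parameters.
def resolveCharset(options: Map[String, String]): String =
  options.getOrElse("encoding",
    options.getOrElse("charset", StandardCharsets.UTF_8.name()))

// resolveCharset(Map("encoding" -> "iso-8859-1")) == "iso-8859-1"
// resolveCharset(Map.empty)                       == "UTF-8"
```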
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index aeb40e5a4131..d59b9820bdee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.csv
+import java.nio.charset.Charset
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce._
@@ -168,7 +170,9 @@ private[csv] class CsvOutputWriter(
context: TaskAttemptContext,
params: CSVOptions) extends OutputWriter with Logging {
- private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
+ private val charset = Charset.forName(params.charset)
+
+ private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)
private val gen = new UnivocityGenerator(dataSchema, writer, params)
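`CodecStreams` is internal to Spark, so as a rough, self-contained approximation of what the new three-argument `createOutputStreamWriter` call boils down to (compression handling, the real job of `CodecStreams`, is deliberately omitted):

```scala
import java.io.{FileOutputStream, OutputStreamWriter}
import java.nio.charset.Charset

// Self-contained approximation: wrap the raw output stream in an
// OutputStreamWriter that encodes with the requested charset.
// Charset.forName throws UnsupportedCharsetException for unknown
// charsets, which is what the error-handling test below relies on.
def openWriter(path: String, charsetName: String): OutputStreamWriter = {
  val charset = Charset.forName(charsetName)
  new OutputStreamWriter(new FileOutputStream(path), charset)
}
```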
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 63cc5985040c..456b4535a0dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -18,12 +18,14 @@
package org.apache.spark.sql.execution.datasources.csv
import java.io.File
-import java.nio.charset.UnsupportedCharsetException
+import java.nio.charset.{Charset, UnsupportedCharsetException}
+import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale
import scala.collection.JavaConverters._
+import scala.util.Properties
import org.apache.commons.lang3.time.FastDateFormat
import org.apache.hadoop.io.SequenceFile.CompressionType
@@ -514,6 +516,41 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
}
}
+ test("SPARK-19018: Save csv with custom charset") {
+
+ // scalastyle:off nonascii
+ val content = "µß áâä ÁÂÄ"
+ // scalastyle:on nonascii
+
+ Seq("iso-8859-1", "utf-8", "utf-16", "utf-32", "windows-1250").foreach { encoding =>
+ withTempPath { path =>
+ val csvDir = new File(path, "csv")
+ Seq(content).toDF().write
+ .option("encoding", encoding)
+ .csv(csvDir.getCanonicalPath)
+
+ csvDir.listFiles().filter(_.getName.endsWith("csv")).foreach { csvFile =>
+ val readback = Files.readAllBytes(csvFile.toPath)
+ val expected = (content + Properties.lineSeparator).getBytes(Charset.forName(encoding))
+ assert(readback === expected)
+ }
+ }
+ }
+ }
+
+ test("SPARK-19018: error handling for unsupported charsets") {
+ val exception = intercept[SparkException] {
+ withTempPath { path =>
+ val csvDir = new File(path, "csv").getCanonicalPath
+ Seq("a,A,c,A,b,B").toDF().write
+ .option("encoding", "1-9588-osi")
+ .csv(csvDir)
+ }
+ }
+
+ assert(exception.getCause.getMessage.contains("1-9588-osi"))
+ }
+
test("commented lines in CSV data") {
Seq("false", "true").foreach { multiLine =>