Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/pyspark/sql/readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,7 +859,7 @@ def text(self, path, compression=None, lineSep=None):
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
header=None, nullValue=None, escapeQuotes=None, quoteAll=None, dateFormat=None,
timestampFormat=None, ignoreLeadingWhiteSpace=None, ignoreTrailingWhiteSpace=None,
charToEscapeQuoteEscaping=None):
charToEscapeQuoteEscaping=None, encoding=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.

:param path: the path in any Hadoop supported file system
Expand Down Expand Up @@ -909,6 +909,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
the quote character. If None is set, the default value is
escape character when escape and quote characters are
different, ``\0`` otherwise..
:param encoding: sets the encoding (charset) of saved csv files. If None is set,
the default UTF-8 charset will be used.

>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
Expand All @@ -918,7 +920,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
dateFormat=dateFormat, timestampFormat=timestampFormat,
ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace,
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
charToEscapeQuoteEscaping=charToEscapeQuoteEscaping,
encoding=encoding)
self._jwrite.csv(path)

@since(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* enclosed in quotes. Default is to only escape values containing a quote character.</li>
* <li>`header` (default `false`): writes the names of columns as the first line.</li>
* <li>`nullValue` (default empty string): sets the string representation of a null value.</li>
* <li>`encoding` (by default it is not set): specifies encoding (charset) of saved csv
* files. If it is not set, the UTF-8 charset will be used.</li>
* <li>`compression` (default `null`): compression codec to use when saving to file. This can be
* one of the known case-insensitive shorten names (`none`, `bzip2`, `gzip`, `lz4`,
* `snappy` and `deflate`). </li>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.datasources.csv

import java.nio.charset.Charset

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.mapreduce._
Expand Down Expand Up @@ -168,7 +170,9 @@ private[csv] class CsvOutputWriter(
context: TaskAttemptContext,
params: CSVOptions) extends OutputWriter with Logging {

private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
private val charset = Charset.forName(params.charset)

private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)

private val gen = new UnivocityGenerator(dataSchema, writer, params)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
package org.apache.spark.sql.execution.datasources.csv

import java.io.File
import java.nio.charset.UnsupportedCharsetException
import java.nio.charset.{Charset, UnsupportedCharsetException}
import java.nio.file.Files
import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat
import java.util.Locale

import scala.collection.JavaConverters._
import scala.util.Properties

import org.apache.commons.lang3.time.FastDateFormat
import org.apache.hadoop.io.SequenceFile.CompressionType
Expand Down Expand Up @@ -514,6 +516,41 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
}
}

test("SPARK-19018: Save csv with custom charset") {

// scalastyle:off nonascii
val content = "µß áâä ÁÂÄ"
// scalastyle:on nonascii

Seq("iso-8859-1", "utf-8", "utf-16", "utf-32", "windows-1250").foreach { encoding =>
withTempPath { path =>
val csvDir = new File(path, "csv")
Seq(content).toDF().write
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: .write.repartition(1) to make sure we write only one file

.option("encoding", encoding)
.csv(csvDir.getCanonicalPath)

csvDir.listFiles().filter(_.getName.endsWith("csv")).foreach({ csvFile =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val readback = Files.readAllBytes(csvFile.toPath)
val expected = (content + Properties.lineSeparator).getBytes(Charset.forName(encoding))
assert(readback === expected)
})
}
}
}

test("SPARK-19018: error handling for unsupported charsets") {
val exception = intercept[SparkException] {
withTempPath { path =>
val csvDir = new File(path, "csv").getCanonicalPath
Seq("a,A,c,A,b,B").toDF().write
.option("encoding", "1-9588-osi")
.csv(csvDir)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could use directly path.getCanonicalPath

}
}

assert(exception.getCause.getMessage.contains("1-9588-osi"))
}

test("commented lines in CSV data") {
Seq("false", "true").foreach { multiLine =>

Expand Down