diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index c081772116f8..47eeb70e0042 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -97,32 +97,16 @@ private[sql] class JSONOptions(
     sep
   }

+  protected def checkedEncoding(enc: String): String = enc
+
   /**
    * Standard encoding (charset) name. For example UTF-8, UTF-16LE and UTF-32BE.
-   * If the encoding is not specified (None), it will be detected automatically
-   * when the multiLine option is set to `true`.
+   * If the encoding is not specified (None) in read, it will be detected automatically
+   * when the multiLine option is set to `true`. If encoding is not specified in write,
+   * UTF-8 is used by default.
    */
   val encoding: Option[String] = parameters.get("encoding")
-    .orElse(parameters.get("charset")).map { enc =>
-      // The following encodings are not supported in per-line mode (multiline is false)
-      // because they cause some problems in reading files with BOM which is supposed to
-      // present in the files with such encodings. After splitting input files by lines,
-      // only the first lines will have the BOM which leads to impossibility for reading
-      // the rest lines. Besides of that, the lineSep option must have the BOM in such
-      // encodings which can never present between lines.
-      val blacklist = Seq(Charset.forName("UTF-16"), Charset.forName("UTF-32"))
-      val isBlacklisted = blacklist.contains(Charset.forName(enc))
-      require(multiLine || !isBlacklisted,
-        s"""The $enc encoding in the blacklist is not allowed when multiLine is disabled.
-           |Blacklist: ${blacklist.mkString(", ")}""".stripMargin)
-
-      val isLineSepRequired =
-        multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty
-
-      require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding")
-
-      enc
-    }
+    .orElse(parameters.get("charset")).map(checkedEncoding)

   val lineSeparatorInRead: Option[Array[Byte]] = lineSeparator.map { lineSep =>
     lineSep.getBytes(encoding.getOrElse("UTF-8"))
@@ -141,3 +125,46 @@ private[sql] class JSONOptions(
     factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars)
   }
 }
+
+private[sql] class JSONOptionsInRead(
+    @transient override val parameters: CaseInsensitiveMap[String],
+    defaultTimeZoneId: String,
+    defaultColumnNameOfCorruptRecord: String)
+  extends JSONOptions(parameters, defaultTimeZoneId, defaultColumnNameOfCorruptRecord) {
+
+  def this(
+      parameters: Map[String, String],
+      defaultTimeZoneId: String,
+      defaultColumnNameOfCorruptRecord: String = "") = {
+    this(
+      CaseInsensitiveMap(parameters),
+      defaultTimeZoneId,
+      defaultColumnNameOfCorruptRecord)
+  }
+
+  protected override def checkedEncoding(enc: String): String = {
+    val isBlacklisted = JSONOptionsInRead.blacklist.contains(Charset.forName(enc))
+    require(multiLine || !isBlacklisted,
+      s"""The ${enc} encoding must not be included in the blacklist when multiLine is disabled:
+         |Blacklist: ${JSONOptionsInRead.blacklist.mkString(", ")}""".stripMargin)
+
+    val isLineSepRequired =
+      multiLine || Charset.forName(enc) == StandardCharsets.UTF_8 || lineSeparator.nonEmpty
+    require(isLineSepRequired, s"The lineSep option must be specified for the $enc encoding")
+
+    enc
+  }
+}
+
+private[sql] object JSONOptionsInRead {
+  // The following encodings are not supported in per-line mode (multiline is false)
+  // because they cause some problems in reading files with BOM which is supposed to
+  // present in the files with such encodings. After splitting input files by lines,
+  // only the first lines will have the BOM which leads to impossibility for reading
+  // the rest lines. Besides of that, the lineSep option must have the BOM in such
+  // encodings which can never present between lines.
+  val blacklist = Seq(
+    Charset.forName("UTF-16"),
+    Charset.forName("UTF-32")
+  )
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 3b04510d2969..e9a0b383b5f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions, JSONOptionsInRead}
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
@@ -40,7 +40,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       path: Path): Boolean = {
-    val parsedOptions = new JSONOptions(
+    val parsedOptions = new JSONOptionsInRead(
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
@@ -52,7 +52,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession: SparkSession,
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType] = {
-    val parsedOptions = new JSONOptions(
+    val parsedOptions = new JSONOptionsInRead(
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
@@ -99,7 +99,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))

-    val parsedOptions = new JSONOptions(
+    val parsedOptions = new JSONOptionsInRead(
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
@@ -158,6 +158,11 @@ private[json] class JsonOutputWriter(
     case None => StandardCharsets.UTF_8
   }

+  if (JSONOptionsInRead.blacklist.contains(encoding)) {
+    logWarning(s"The JSON file ($path) was written in the encoding ${encoding.displayName()}" +
+      " which can be read back by Spark only if multiLine is enabled.")
+  }
+
   private val writer = CodecStreams.createOutputStreamWriter(
     context, new Path(path), encoding)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index a8a4a524a97f..897424daca0c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -17,9 +17,9 @@

 package org.apache.spark.sql.execution.datasources.json

-import java.io.{File, FileOutputStream, StringWriter}
-import java.nio.charset.{StandardCharsets, UnsupportedCharsetException}
-import java.nio.file.{Files, Paths, StandardOpenOption}
+import java.io._
+import java.nio.charset.{Charset, StandardCharsets, UnsupportedCharsetException}
+import java.nio.file.Files
 import java.sql.{Date, Timestamp}
 import java.util.Locale

@@ -2262,7 +2262,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
     withTempPath { path =>
       val df = spark.createDataset(Seq(("Dog", 42)))
       df.write
-        .options(Map("encoding" -> encoding, "lineSep" -> "\n"))
+        .options(Map("encoding" -> encoding))
         .json(path.getCanonicalPath)

       checkEncoding(
@@ -2286,16 +2286,22 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   test("SPARK-23723: wrong output encoding") {
     val encoding = "UTF-128"
-    val exception = intercept[UnsupportedCharsetException] {
+    val exception = intercept[SparkException] {
       withTempPath { path =>
         val df = spark.createDataset(Seq((0)))
         df.write
-          .options(Map("encoding" -> encoding, "lineSep" -> "\n"))
+          .options(Map("encoding" -> encoding))
           .json(path.getCanonicalPath)
       }
     }
-    assert(exception.getMessage == encoding)
+    val baos = new ByteArrayOutputStream()
+    val ps = new PrintStream(baos, true, "UTF-8")
+    exception.printStackTrace(ps)
+    ps.flush()
+
+    assert(baos.toString.contains(
+      "java.nio.charset.UnsupportedCharsetException: UTF-128"))
   }

   test("SPARK-23723: read back json in UTF-16LE") {
@@ -2316,18 +2322,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   test("SPARK-23723: write json in UTF-16/32 with multiline off") {
     Seq("UTF-16", "UTF-32").foreach { encoding =>
       withTempPath { path =>
-        val ds = spark.createDataset(Seq(
-          ("a", 1), ("b", 2), ("c", 3))
-        ).repartition(2)
-        val e = intercept[IllegalArgumentException] {
-          ds.write
-            .option("encoding", encoding)
-            .option("multiline", "false")
-            .format("json").mode("overwrite")
-            .save(path.getCanonicalPath)
-        }.getMessage
-        assert(e.contains(
-          s"$encoding encoding in the blacklist is not allowed when multiLine is disabled"))
+        val ds = spark.createDataset(Seq(("a", 1))).repartition(1)
+        ds.write
+          .option("encoding", encoding)
+          .option("multiline", false)
+          .json(path.getCanonicalPath)
+        val jsonFiles = path.listFiles().filter(_.getName.endsWith("json"))
+        jsonFiles.foreach { jsonFile =>
+          val readback = Files.readAllBytes(jsonFile.toPath)
+          val expected = ("""{"_1":"a","_2":1}""" + "\n").getBytes(Charset.forName(encoding))
+          assert(readback === expected)
+        }
       }
     }
   }
@@ -2476,4 +2481,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       checkAnswer(df, Row(Row(1, "string")) :: Row(Row(2, null)) :: Row(null) :: Nil)
     }
   }
+
+  test("SPARK-24190: restrictions for JSONOptions in read") {
+    for (encoding <- Set("UTF-16", "UTF-32")) {
+      val exception = intercept[IllegalArgumentException] {
+        spark.read
+          .option("encoding", encoding)
+          .option("multiLine", false)
+          .json(testFile("test-data/utf16LE.json"))
+          .count()
+      }
+      assert(exception.getMessage.contains("encoding must not be included in the blacklist"))
+    }
+  }
 }
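
Below is a minimal usage sketch, not part of the patch, of the behavior the rewritten tests exercise. It assumes a local SparkSession bound to the name spark and a hypothetical scratch path /tmp/spark-json-utf16; neither comes from the patch itself. With this change, writing JSON in a blacklisted encoding such as UTF-16 with multiLine off succeeds and only triggers the JsonOutputWriter warning, while reading those files back with multiLine off is rejected by JSONOptionsInRead.checkedEncoding.

    import scala.util.Try

    import org.apache.spark.sql.SparkSession

    object JsonEncodingSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .appName("json-encoding-sketch")
          .getOrCreate()
        import spark.implicits._

        // Hypothetical scratch location, not taken from the patch.
        val path = "/tmp/spark-json-utf16"

        // The write side no longer rejects UTF-16 when multiLine is off;
        // JsonOutputWriter only logs a warning that multiLine is needed to read the files back.
        Seq(("a", 1)).toDS()
          .write
          .mode("overwrite")
          .option("encoding", "UTF-16")
          .json(path)

        // The read side still enforces the blacklist: with multiLine disabled,
        // JSONOptionsInRead.checkedEncoding fails its require() call.
        val attempt = Try {
          spark.read
            .option("encoding", "UTF-16")
            .option("multiLine", false)
            .json(path)
            .count()
        }
        // Expected to print a Failure wrapping an IllegalArgumentException that mentions the blacklist.
        println(attempt)

        spark.stop()
      }
    }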