diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
index 3ff36bfde3cc..2b549536ae48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala
@@ -33,25 +33,17 @@ class CsvOutputWriter(
     context: TaskAttemptContext,
     params: CSVOptions) extends OutputWriter with Logging {
 
-  private var univocityGenerator: Option[UnivocityGenerator] = None
+  private val charset = Charset.forName(params.charset)
+
+  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)
+
+  private val gen = new UnivocityGenerator(dataSchema, writer, params)
 
   if (params.headerFlag) {
-    val gen = getGen()
     gen.writeHeaders()
   }
 
-  private def getGen(): UnivocityGenerator = univocityGenerator.getOrElse {
-    val charset = Charset.forName(params.charset)
-    val os = CodecStreams.createOutputStreamWriter(context, new Path(path), charset)
-    val newGen = new UnivocityGenerator(dataSchema, os, params)
-    univocityGenerator = Some(newGen)
-    newGen
-  }
-
-  override def write(row: InternalRow): Unit = {
-    val gen = getGen()
-    gen.write(row)
-  }
+  override def write(row: InternalRow): Unit = gen.write(row)
 
-  override def close(): Unit = univocityGenerator.foreach(_.close())
+  override def close(): Unit = gen.close()
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala
index b3cd570cfb1c..dfd84e344eb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala
@@ -44,20 +44,18 @@ class JsonOutputWriter(
       " which can be read back by Spark only if multiLine is enabled.")
   }
 
-  private var jacksonGenerator: Option[JacksonGenerator] = None
+  private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path), encoding)
 
-  override def write(row: InternalRow): Unit = {
-    val gen = jacksonGenerator.getOrElse {
-      val os = CodecStreams.createOutputStreamWriter(context, new Path(path), encoding)
-      // create the Generator without separator inserted between 2 records
-      val newGen = new JacksonGenerator(dataSchema, os, options)
-      jacksonGenerator = Some(newGen)
-      newGen
-    }
+  // create the Generator without separator inserted between 2 records
+  private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
 
+  override def write(row: InternalRow): Unit = {
     gen.write(row)
     gen.writeLineEnding()
   }
 
-  override def close(): Unit = jacksonGenerator.foreach(_.close())
+  override def close(): Unit = {
+    gen.close()
+    writer.close()
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala
index faf6e573105f..2b1b81f60ceb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala
@@ -16,8 +16,6 @@
  */
 
 package org.apache.spark.sql.execution.datasources.text
 
-import java.io.OutputStream
-
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.TaskAttemptContext
@@ -32,23 +30,17 @@ class TextOutputWriter(
     context: TaskAttemptContext)
   extends OutputWriter {
 
-  private var outputStream: Option[OutputStream] = None
+  private val writer = CodecStreams.createOutputStream(context, new Path(path))
 
   override def write(row: InternalRow): Unit = {
-    val os = outputStream.getOrElse {
-      val newStream = CodecStreams.createOutputStream(context, new Path(path))
-      outputStream = Some(newStream)
-      newStream
-    }
-
     if (!row.isNullAt(0)) {
       val utf8string = row.getUTF8String(0)
-      utf8string.writeTo(os)
+      utf8string.writeTo(writer)
     }
-    os.write(lineSeparator)
+    writer.write(lineSeparator)
   }
 
   override def close(): Unit = {
-    outputStream.foreach(_.close())
+    writer.close()
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
index be7973b9d930..f6cc8116c6c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
@@ -22,7 +22,7 @@ import java.util.UUID
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
 
 import org.apache.spark.internal.Logging
@@ -89,7 +89,9 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
       try {
        val fs = path.getFileSystem(jobContext.getConfiguration)
        // this is to make sure the file can be seen from driver as well
-        deleteIfExists(fs, path)
+        if (fs.exists(path)) {
+          fs.delete(path, false)
+        }
      } catch {
        case e: IOException =>
          logWarning(s"Fail to remove temporary file $path, continue removing next.", e)
@@ -137,14 +139,7 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
     if (addedFiles.nonEmpty) {
       val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)
       val statuses: Seq[SinkFileStatus] =
-        addedFiles.flatMap { f =>
-          val path = new Path(f)
-          if (fs.exists(path)) {
-            Some(SinkFileStatus(fs.getFileStatus(path)))
-          } else {
-            None
-          }
-        }
+        addedFiles.map(f => SinkFileStatus(fs.getFileStatus(new Path(f))))
       new TaskCommitMessage(statuses)
     } else {
       new TaskCommitMessage(Seq.empty[SinkFileStatus])
@@ -155,13 +150,7 @@ class ManifestFileCommitProtocol(jobId: String, path: String)
     // best effort cleanup of incomplete files
     if (addedFiles.nonEmpty) {
       val fs = new Path(addedFiles.head).getFileSystem(taskContext.getConfiguration)
-      addedFiles.foreach { file => deleteIfExists(fs, new Path(file)) }
-    }
-  }
-
-  private def deleteIfExists(fs: FileSystem, path: Path, recursive: Boolean = false): Unit = {
-    if (fs.exists(path)) {
-      fs.delete(path, recursive)
+      addedFiles.foreach { file => fs.delete(new Path(file), false) }
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index a6c3a51858ae..a6e58cec1036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -2073,15 +2073,6 @@ class CSVSuite extends QueryTest with SharedSparkSession with TestCsvData {
     }
   }
 
-  test("do not produce empty files for empty partitions") {
-    withTempPath { dir =>
-      val path = dir.getCanonicalPath
-      spark.emptyDataset[String].write.csv(path)
-      val files = new File(path).listFiles()
-      assert(!files.exists(_.getName.endsWith("csv")))
-    }
-  }
-
   test("Do not reuse last good value for bad input field") {
     val schema = StructType(
       StructField("col1", StringType) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 3574aa266b35..e3e0195f08e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2471,15 +2471,6 @@ class JsonSuite extends QueryTest with SharedSparkSession with TestJsonData {
     emptyString(BinaryType, "".getBytes(StandardCharsets.UTF_8))
   }
 
-  test("do not produce empty files for empty partitions") {
-    withTempPath { dir =>
-      val path = dir.getCanonicalPath
-      spark.emptyDataset[String].write.json(path)
-      val files = new File(path).listFiles()
-      assert(!files.exists(_.getName.endsWith("json")))
-    }
-  }
-
   test("return partial result for bad records") {
     val schema = "a double, b array<int>, c string, _corrupt_record string"
     val badRecords = Seq(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
index 62a779528cec..539ff0d0e905 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala
@@ -233,13 +233,4 @@ class TextSuite extends QueryTest with SharedSparkSession {
     assert(data(3) == Row("\"doh\""))
     assert(data.length == 4)
   }
-
-  test("do not produce empty files for empty partitions") {
-    withTempPath { dir =>
-      val path = dir.getCanonicalPath
-      spark.emptyDataset[String].write.text(path)
-      val files = new File(path).listFiles()
-      assert(!files.exists(_.getName.endsWith("txt")))
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index f04da8bfc448..9bce7f3568e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -525,6 +525,54 @@ abstract class FileStreamSinkSuite extends StreamTest {
       }
     }
   }
+
+  test("Handle FileStreamSink metadata correctly for empty partition") {
+    Seq("parquet", "orc", "text", "json").foreach { format =>
+      val inputData = MemoryStream[String]
+      val df = inputData.toDF()
+
+      withTempDir { outputDir =>
+        withTempDir { checkpointDir =>
+          var query: StreamingQuery = null
+          try {
+            // repartition to more than the input to leave empty partitions
+            query =
+              df.repartition(10)
+                .writeStream
+                .option("checkpointLocation", checkpointDir.getCanonicalPath)
+                .format(format)
+                .start(outputDir.getCanonicalPath)
+
+            inputData.addData("1", "2", "3")
+            inputData.addData("4", "5")
+
+            failAfter(streamingTimeout) {
+              query.processAllAvailable()
+            }
+          } finally {
+            if (query != null) {
+              query.stop()
+            }
+          }
+
+          val fs = new Path(outputDir.getCanonicalPath).getFileSystem(
+            spark.sessionState.newHadoopConf())
+          val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark,
+            outputDir.getCanonicalPath)
+
+          val allFiles = sinkLog.allFiles()
+          // only files from non-empty partition should be logged
+          assert(allFiles.length < 10)
+          assert(allFiles.forall(file => fs.exists(new Path(file.path))))
+
+          // the query should be able to read all rows correctly with metadata log
+          val outputDf = spark.read.format(format).load(outputDir.getCanonicalPath)
+            .selectExpr("CAST(value AS INT)").as[Int]
+          checkDatasetUnorderly(outputDf, 1, 2, 3, 4, 5)
+        }
+      }
+    }
+  }
 }
 
 object PendingCommitFilesTrackingManifestFileCommitProtocol {
@@ -600,61 +648,11 @@ class FileStreamSinkV1Suite extends FileStreamSinkSuite {
 }
 
 class FileStreamSinkV2Suite extends FileStreamSinkSuite {
-  import testImplicits._
-
   override protected def sparkConf: SparkConf =
     super
       .sparkConf
       .set(SQLConf.USE_V1_SOURCE_LIST, "")
 
-  test("SPARK-29999 Handle FileStreamSink metadata correctly for empty partition") {
-    Seq("parquet", "orc", "text", "json").foreach { format =>
-      val inputData = MemoryStream[String]
-      val df = inputData.toDF()
-
-      withTempDir { outputDir =>
-        withTempDir { checkpointDir =>
-          var query: StreamingQuery = null
-          try {
-            // repartition to more than the input to leave empty partitions
-            query =
-              df.repartition(10)
-                .writeStream
-                .option("checkpointLocation", checkpointDir.getCanonicalPath)
-                .format(format)
-                .start(outputDir.getCanonicalPath)
-
-            inputData.addData("1", "2", "3")
-            inputData.addData("4", "5")
-
-            failAfter(streamingTimeout) {
-              query.processAllAvailable()
-            }
-          } finally {
-            if (query != null) {
-              query.stop()
-            }
-          }
-
-          val fs = new Path(outputDir.getCanonicalPath).getFileSystem(
-            spark.sessionState.newHadoopConf())
-          val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark,
-            outputDir.getCanonicalPath)
-
-          val allFiles = sinkLog.allFiles()
-          // only files from non-empty partition should be logged
-          assert(allFiles.length < 10)
-          assert(allFiles.forall(file => fs.exists(new Path(file.path))))
-
-          // the query should be able to read all rows correctly with metadata log
-          val outputDf = spark.read.format(format).load(outputDir.getCanonicalPath)
-            .selectExpr("CAST(value AS INT)").as[Int]
-          checkDatasetUnorderly(outputDf, 1, 2, 3, 4, 5)
-        }
-      }
-    }
-  }
-
   override def checkQueryExecution(df: DataFrame): Unit = {
     // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
     // been inferred