diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 738d8fee891d..f1be2352ea32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -973,7 +973,7 @@ object SQLConf {
     buildConf("spark.sql.streaming.commitProtocolClass")
       .internal()
       .stringConf
-      .createWithDefault("org.apache.spark.sql.execution.streaming.ManifestFileCommitProtocol")
+      .createWithDefault("org.apache.spark.sql.execution.streaming.StagingFileCommitProtocol")
 
   val STREAMING_MULTIPLE_WATERMARK_POLICY =
     buildConf("spark.sql.streaming.multipleWatermarkPolicy")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 7c6ab4bc922f..32cff274e569 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -218,7 +218,7 @@ object FileFormatWriter extends Logging {
       hadoopConf.set("mapreduce.task.id", taskAttemptId.getTaskID.toString)
       hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString)
       hadoopConf.setBoolean("mapreduce.task.ismap", true)
-      hadoopConf.setInt("mapreduce.task.partition", 0)
+      hadoopConf.setInt("mapreduce.task.partition", sparkPartitionId)
 
       new TaskAttemptContextImpl(hadoopConf, taskAttemptId)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index b3d12f67b5d6..b837e76d5fda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -113,7 +113,7 @@ class FileStreamSink(
         outputPath = path)
 
       committer match {
-        case manifestCommitter: ManifestFileCommitProtocol =>
+        case manifestCommitter: ManifestCommitProtocol =>
          manifestCommitter.setupManifestOptions(fileLog, batchId)
         case _ => // Do nothing
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestCommitProtocol.scala
new file mode 100644
index 000000000000..52c08979ba5a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestCommitProtocol.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+trait ManifestCommitProtocol {
+  @transient protected var fileLog: MetadataLog[Array[SinkFileStatus]] = _
+  protected var batchId: Long = _
+  /**
+   * Sets up the manifest log output and the batch id for this job.
+   * Must be called before any other function.
+   */
+  def setupManifestOptions(fileLog: MetadataLog[Array[SinkFileStatus]], batchId: Long): Unit = {
+    this.fileLog = fileLog
+    this.batchId = batchId
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
index 92191c8b64b7..2ce8dde4232b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala
@@ -35,23 +35,12 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
  * @param path path to write the final output to.
  */
 class ManifestFileCommitProtocol(jobId: String, path: String)
-  extends FileCommitProtocol with Serializable with Logging {
+  extends FileCommitProtocol with Serializable with Logging
+  with ManifestCommitProtocol {
 
   // Track the list of files added by a task, only used on the executors.
   @transient private var addedFiles: ArrayBuffer[String] = _
 
-  @transient private var fileLog: FileStreamSinkLog = _
-  private var batchId: Long = _
-
-  /**
-   * Sets up the manifest log output and the batch id for this job.
-   * Must be called before any other function.
-   */
-  def setupManifestOptions(fileLog: FileStreamSinkLog, batchId: Long): Unit = {
-    this.fileLog = fileLog
-    this.batchId = batchId
-  }
-
   override def setupJob(jobContext: JobContext): Unit = {
     require(fileLog != null, "setupManifestOptions must be called before this function")
     // Do nothing
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StagingFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StagingFileCommitProtocol.scala
new file mode 100644
index 000000000000..72868a71098a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StagingFileCommitProtocol.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.hadoop.fs.{FileAlreadyExistsException, FileContext, Path}
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+
+class StagingFileCommitProtocol(jobId: String, path: String)
+  extends FileCommitProtocol with Serializable with Logging
+  with ManifestCommitProtocol {
+  private var stagingDir: Option[Path] = None
+
+
+  def jobStagingDir: Path = {
+    new Path(new Path(path, "staging"), s"job-$jobId")
+  }
+
+  override def setupJob(jobContext: JobContext): Unit = {
+    jobStagingDir.getFileSystem(jobContext.getConfiguration).delete(jobStagingDir, true)
+    logInfo(s"Job $jobId set up")
+  }
+
+
+  override def setupTask(taskContext: TaskAttemptContext): Unit = {
+    stagingDir = Some(new Path(jobStagingDir, s"partition-${partition(taskContext)}"))
+    stagingDir.get.getFileSystem(taskContext.getConfiguration).delete(stagingDir.get, true)
+    logInfo(s"Task set up to handle partition ${partition(taskContext)} in job $jobId")
+
+  }
+
+  private def partition(taskContext: TaskAttemptContext) = {
+    taskContext.getConfiguration.getInt("mapreduce.task.partition", -1)
+  }
+
+
+  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
+    val fs = jobStagingDir.getFileSystem(jobContext.getConfiguration)
+    val fileCtx = FileContext.getFileContext
+
+    def moveIfPossible(next: Path, target: Path) = {
+      try {
+        fileCtx.rename(next, target)
+      } catch {
+        case _: FileAlreadyExistsException =>
+          val status = fileCtx.getFileStatus(target)
+          logWarning(s"File ${target.toUri.toASCIIString} has already been generated " +
+            s"earlier (${status.toString}), deleting instead of moving " +
+            s"recently generated file: ${fileCtx.getFileStatus(next).toString}")
+          fileCtx.delete(next, false)
+      }
+    }
+
+    def moveEach(from: Path, to: String) = {
+      val files = fs.listFiles(from, true)
+      val statuses = Array.newBuilder[SinkFileStatus]
+      while (files.hasNext) {
+        val next = files.next().getPath
+        val target = if (next.getParent.getName.startsWith(outputPartitionPrefix)) {
+          val subdir = next.getParent.getName.substring(outputPartitionPrefix.length)
+            .replaceAll(subdirEscapeSequence, "/")
+          val outputPartition = new Path(to, subdir)
+          fs.mkdirs(outputPartition)
+          new Path(outputPartition, next.getName)
+        } else {
+          new Path(to, next.getName)
+        }
+        moveIfPossible(next, target)
+        statuses += SinkFileStatus(fs.getFileStatus(target))
+      }
+      if (fileLog.add(batchId, statuses.result)) {
+        logInfo(s"Job $jobId committed")
+      } else {
+        throw new IllegalStateException(s"Race while writing batch $batchId")
+      }
+    }
+
+    moveEach(jobStagingDir, path)
+
+    Seq()
+  }
+
+  override def abortJob(jobContext: JobContext): Unit = {}
+
+
+  private var fileCounter: Int = -1
+
+  private def nextCounter: Int = {
+    fileCounter += 1
+    fileCounter
+  }
+
+  private val outputPartitionPrefix = "part_prefix_"
+
+  private val subdirEscapeSequence = "___per___"
+
+  override def newTaskTempFile(
+      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
+    val staging = stagingDir.getOrElse(
+      throw new IllegalStateException("Staging dir needs to be initialized in setupTask()"))
+    val targetDir = dir.map(d => new Path(staging, stagingReplacementDir(d))).getOrElse(staging)
+    val res = new Path(targetDir,
s"part-j$jobId-p${partition(taskContext)}-c$nextCounter$ext") + .toString + logInfo(s"New file generated $res") + res + } + + private def stagingReplacementDir(d: String) = { + outputPartitionPrefix + d.replaceAll("/", subdirEscapeSequence) + } + + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + throw new UnsupportedOperationException( + s"$this does not support adding files with an absolute path") + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + new TaskCommitMessage(None) + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = {} +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StagingFileCommitProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StagingFileCommitProtocolSuite.scala new file mode 100644 index 000000000000..6853daffb9ee --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/StagingFileCommitProtocolSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.File
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs._
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.mockito.{Matchers, Mockito}
+import org.mockito.Mockito.when
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.io.FileCommitProtocol.EmptyTaskCommitMessage
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.Utils
+
+class StagingFileCommitProtocolSuite extends SparkFunSuite with BeforeAndAfter {
+  val array: Array[Byte] = new Array[Byte](1000)
+  val taskId = 333
+  val taskId2 = 334
+  val jobID = 123
+  val attemptId: Int = 444
+  val attemptIdNext: Int = 445
+  val taskAttemptId: TaskAttemptID = new TaskAttemptID("SPARK", jobID, true, taskId, attemptId)
+  val taskAttemptId2: TaskAttemptID = new TaskAttemptID("SPARK", jobID, true, taskId2, attemptId)
+  val basePath = Utils.createTempDir().getCanonicalFile.toString
+  val protocol = newCommitProtocol(jobID)
+  val hadoopConf = createConf(0, taskAttemptId)
+  val tx = new TaskAttemptContextImpl(hadoopConf, taskAttemptId)
+  val fs = new Path(basePath).getFileSystem(hadoopConf)
+  val ctx: JobContext = Job.getInstance(hadoopConf)
+
+  val session = new TestSparkSession()
+  def newCommitProtocol(batchId: Int): StagingFileCommitProtocol = {
+    val p = new StagingFileCommitProtocol(batchId.toString, basePath)
+    val logClass = classOf[MetadataLog[Array[SinkFileStatus]]]
+    val log: MetadataLog[Array[SinkFileStatus]] = Mockito.mock(logClass)
+    when(log.add(Matchers.anyInt, Matchers.any(classOf[Array[SinkFileStatus]])))
+      .thenReturn(true)
+
+    p.setupManifestOptions(
+      log,
+      batchId)
+    p
+  }
+
+  def createConf(partition: Int, taskAttemptId: TaskAttemptID): Configuration = {
+    val hc = new Configuration()
+    hc.set("mapreduce.job.id", jobID.toString)
+    hc.set("mapreduce.task.id", taskAttemptId.getTaskID.toString)
+    hc.set("mapreduce.task.attempt.id", taskAttemptId.toString)
+    hc.setBoolean("mapreduce.task.ismap", true)
+    hc.setInt("mapreduce.task.partition", partition)
+    session
+    hc
+  }
+
+  after {
+    Utils.deleteRecursively(new File(basePath))
+  }
+
+  test("file is generated on job commit") {
+    protocol.setupJob(ctx)
+
+    protocol.setupTask(tx)
+    val fileName = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName, "data")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+
+    protocol.commitJob(ctx, Seq(EmptyTaskCommitMessage))
+
+    assert(fileContents == Set("data"))
+  }
+
+  test("file is generated into partition subdirectory") {
+    protocol.setupJob(ctx)
+
+    protocol.setupTask(tx)
+    val fileName = protocol.newTaskTempFile(tx, Some("subdir"), "ext")
+    writeToFile(fileName, "data")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+
+    protocol.commitJob(ctx, Seq(EmptyTaskCommitMessage))
+
+    assert(fileContents(new Path(basePath, "subdir")) == Set("data"))
+  }
+
+  test("before job commit file is not visible") {
+    protocol.setupJob(ctx)
+
+    protocol.setupTask(tx)
+    val fileName = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName, "data")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+
+    assert(fileContents == Set())
+
+  }
+
+  test("2 tasks can write in the same job") {
+    protocol.setupJob(ctx)
+    protocol.setupTask(tx)
+
+    val protocol2 = newCommitProtocol(jobID)
+    val tx2 = new TaskAttemptContextImpl(createConf(1, taskAttemptId2), taskAttemptId2)
+    protocol2.setupTask(tx2)
+
+    val fileName = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName, "data0")
+
+    val fileName2 = protocol2.newTaskTempFile(tx2, None, "ext")
+    writeToFile(fileName2, "data1")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+    protocol2.onTaskCommit(EmptyTaskCommitMessage)
+
+    protocol.commitJob(ctx, Seq())
+    assert(fileContents == Set("data0", "data1"))
+  }
+
+  test("same task can be executed twice") {
+    protocol.setupJob(ctx)
+    protocol.setupTask(tx)
+
+    val fileName = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName, "data")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+
+    protocol.commitJob(ctx, Seq())
+
+    val protocol2 = newCommitProtocol(jobID)
+    val attempt = new TaskAttemptID("SPARK", jobID, true, taskId, attemptIdNext)
+    hadoopConf.set("mapreduce.task.attempt.id", attempt.toString)
+    val tx2 = new TaskAttemptContextImpl(hadoopConf, attempt)
+    protocol2.setupJob(tx2)
+    protocol2.setupTask(tx2)
+
+    val fileName2 = protocol2.newTaskTempFile(tx2, None, "ext")
+    writeToFile(fileName2, "data")
+    protocol2.onTaskCommit(EmptyTaskCommitMessage)
+    protocol2.commitJob(tx2, Seq())
+
+    assert(fileContents == Set("data"))
+
+  }
+
+  test("task without job commit can be restarted") {
+    protocol.setupJob(ctx)
+    protocol.setupTask(tx)
+
+    val fileName = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName, "data")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+
+    val protocol2 = newCommitProtocol(jobID)
+    val attempt = new TaskAttemptID("SPARK", jobID, true, taskId, attemptIdNext)
+    hadoopConf.set("mapreduce.task.attempt.id", attempt.toString)
+    val tx2 = new TaskAttemptContextImpl(hadoopConf, attempt)
+    protocol2.setupJob(tx2)
+    protocol2.setupTask(tx2)
+
+    val fileName2 = protocol2.newTaskTempFile(tx2, None, "ext")
+    writeToFile(fileName2, "data")
+    protocol2.onTaskCommit(EmptyTaskCommitMessage)
+    protocol2.commitJob(tx2, Seq())
+
+    assert(fileContents == Set("data"))
+
+  }
+
+
+  test("multiple files can be generated by same task") {
+    protocol.setupJob(ctx)
+
+    protocol.setupTask(tx)
+    val fileName = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName, "data0")
+
+    val fileName2 = protocol.newTaskTempFile(tx, None, "ext")
+    writeToFile(fileName2, "data1")
+
+    protocol.onTaskCommit(EmptyTaskCommitMessage)
+
+    protocol.commitJob(ctx, Seq(EmptyTaskCommitMessage))
+
+    assert(fileContents == Set("data0", "data1"))
+  }
+
+
+  private def fileContents: Set[String] = {
+    val path = new Path(basePath)
+    fileContents(path)
+  }
+
+  private def fileContents(path: Path): Set[String] = {
+    val files: RemoteIterator[LocatedFileStatus] = fs.listFiles(path, false)
+
+    val fileList = Seq.newBuilder[LocatedFileStatus]
+    while (files.hasNext) {
+      fileList += files.next()
+    }
+    val allData = fileList.result().map(f => {
+      val stream = fs.open(f.getPath)
+      val length = stream.read(array)
+      array.slice(0, length).map(_.toChar).mkString
+    })
+    allData.toSet
+  }
+
+  private def writeToFile(fileName: String, data: String) = {
+    val file = new Path(fileName)
+    val fs = file.getFileSystem(hadoopConf)
+    val os = fs.create(file)
+    os.write(data.getBytes())
+    os.close()
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkUnitSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkUnitSuite.scala
new file mode 100644
index 000000000000..70189ef05850
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkUnitSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.io.{File, FilenameFilter}
+
+import com.google.common.io.PatternFilenameFilter
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.mapreduce.JobContext
+import org.scalatest.BeforeAndAfter
+import scala.io.Source
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.execution.streaming.{FileStreamSink, ManifestFileCommitProtocol, StagingFileCommitProtocol}
+import org.apache.spark.sql.internal.SQLConf.STREAMING_FILE_COMMIT_PROTOCOL_CLASS
+
+
+class FailingManifestFileCommitProtocol(jobId: String, path: String)
+  extends StagingFileCommitProtocol(jobId, path) {
+  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
+    logError("Skipping job commit simulating ungraceful shutdown")
+  }
+}
+
+class ExceptionThrowingManifestFileCommitProtocol(jobId: String, path: String)
+  extends StagingFileCommitProtocol(jobId, path) {
+  override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = {
+    throw new IllegalStateException("Simulating exception on job commit")
+  }
+}
+
+
+class FileStreamSinkUnitSuite extends StreamTest with BeforeAndAfter {
+
+  import testImplicits._
+
+  val dir = new File("destination_path")
+  val fruits = Seq("apple", "peach", "citron")
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  before {
+    FileUtils.deleteQuietly(dir)
+  }
+
+  test("add batch results in files") {
+    val df = fruits.toDF()
+    val sink = new FileStreamSink(spark, dir.getName, new TextFileFormat(), Seq.empty, Map.empty)
+    SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, df.logicalPlan)) {
+      sink.addBatch(1, df)
+    }
+
+    assertResults(dir, fruits)
+  }
+
+  test("add same batch again will not duplicate results") {
+    val df = fruits.toDF()
+    val sink = new FileStreamSink(spark, dir.getName, new TextFileFormat(), Seq.empty, Map.empty)
+    SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, df.logicalPlan)) {
+      sink.addBatch(1, df)
+      sink.addBatch(1, df)
+    }
+
+    assertResults(dir, fruits)
+  }
+
+  test("adding batch again after job commit failure should not duplicate items") {
+    val df = fruits.toDF()
+    var sink = new FileStreamSink(spark, dir.getName, new TextFileFormat(), Seq.empty, Map.empty)
+    sqlContext.setConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS.key,
+      classOf[FailingManifestFileCommitProtocol].getCanonicalName)
+    SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, df.logicalPlan)) {
+      sink.addBatch(1, df)
+    }
+
+    sqlContext.setConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS.key,
+      STREAMING_FILE_COMMIT_PROTOCOL_CLASS.defaultValueString)
+    SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, df.logicalPlan)) {
+      sink = new FileStreamSink(spark, dir.getName, new TextFileFormat(), Seq.empty, Map.empty)
+      sink.addBatch(1, df)
+    }
+
+    assertResults(dir, fruits)
+  }
+
+  test("adding batch again after job commit throwing exception should not duplicate items") {
+    val df = fruits.toDF()
+    var sink = new FileStreamSink(spark, dir.getName, new TextFileFormat(), Seq.empty, Map.empty)
+    sqlContext.setConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS.key,
+      classOf[ExceptionThrowingManifestFileCommitProtocol].getCanonicalName)
+    intercept[SparkException] {
+      SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, df.logicalPlan)) {
+        sink.addBatch(1, df)
+      }
+    }
+
+    sqlContext.setConf(STREAMING_FILE_COMMIT_PROTOCOL_CLASS.key,
+      STREAMING_FILE_COMMIT_PROTOCOL_CLASS.defaultValueString)
+    SQLExecution.withNewExecutionId(spark, new QueryExecution(spark, df.logicalPlan)) {
+      sink = new FileStreamSink(spark, dir.getName, new TextFileFormat(), Seq.empty, Map.empty)
+      sink.addBatch(1, df)
+    }
+
+    assertResults(dir, fruits)
+  }
+
+  private def assertResults(dir: File, fruits: Seq[String]) = {
+    val output = dir.listFiles(new PatternFilenameFilter("part-.*")).flatMap {
+      file =>
+        val source = Source.fromFile(file)
+        source.getLines().toSeq
+    }.toSeq.sortBy(x => x)
+
+    assert(output == fruits.sortBy(x => x))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
index 28412ea07a75..e94cfd512718 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamStressSuite.scala
@@ -151,6 +151,6 @@ class FileStreamStressSuite extends StreamTest {
     }
 
     logError(s"Stream restarted $failures times.")
-    assert(spark.read.parquet(outputDir).distinct().count() == numRecords)
+    assert(spark.read.parquet(outputDir).count() == numRecords)
   }
 }