diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java new file mode 100644 index 000000000000..a8a961598bde --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability and save the data to the data source. + */ +@InterfaceStability.Evolving +public interface WriteSupport { + + /** + * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * sources can return None if there is no writing needed to be done according to the save mode. + * + * @param jobId A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceV2Writer} should + * use this job id to distinguish itself with writers of other jobs. + * @param schema the schema of the data to be written. + * @param mode the save mode which determines what to do when the data are already in this data + * source, please refer to {@link SaveMode} for more details. + * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. 
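As a rough illustration of this contract (not part of the patch), a file-based source might check the save mode before handing back a writer. MySource, MyWriter and dataAlreadyExists below are hypothetical names; the sketch only shows the shape of an implementation.

import java.util.Optional

import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport}
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

class MySource extends DataSourceV2 with WriteSupport {

  override def createWriter(
      jobId: String,
      schema: StructType,
      mode: SaveMode,
      options: DataSourceV2Options): Optional[DataSourceV2Writer] = {
    val path = options.get("path").get()
    mode match {
      case SaveMode.ErrorIfExists if dataAlreadyExists(path) =>
        throw new RuntimeException(s"data already exists at $path")
      case SaveMode.Ignore if dataAlreadyExists(path) =>
        Optional.empty() // nothing to write, so save() becomes a no-op
      case _ =>
        Optional.of(new MyWriter(jobId, path): DataSourceV2Writer)
    }
  }

  private def dataAlreadyExists(path: String): Boolean = ??? // illustrative helper

  // Stub writer; fuller sketches of the writer-side interfaces follow below.
  private class MyWriter(jobId: String, path: String) extends DataSourceV2Writer {
    override def createWriterFactory(): DataWriterFactory[Row] = ???
    override def commit(messages: Array[WriterCommitMessage]): Unit = ???
    override def abort(messages: Array[WriterCommitMessage]): Unit = ???
  }
}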
+ */ + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d6f297c01337..6b0c9d417eea 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.reader; import org.apache.spark.annotation.InterfaceStability; -import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.sources.Filter; /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java new file mode 100644 index 000000000000..8d8e33633fb0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A data source writer that is returned by + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * It can mix in various writing optimization interfaces to speed up the data saving. The actual + * writing logic is delegated to {@link DataWriter}. + * + * The writing procedure is: + * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the + * partitions of the input data(RDD). + * 2. For each partition, create the data writer, and write the data of the partition with this + * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If + * exception happens during the writing, call {@link DataWriter#abort()}. + * 3. If all writers are successfully committed, call {@link #commit(WriterCommitMessage[])}. If + * some writers are aborted, or the job failed with an unknown reason, call + * {@link #abort(WriterCommitMessage[])}. + * + * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if + * they want to retry. + * + * Please refer to the document of commit/abort methods for detailed specifications. 
+ * + * Note that, this interface provides a protocol between Spark and data sources for transactional + * data writing, but the transaction here is Spark-level transaction, which may not be the + * underlying storage transaction. For example, Spark successfully writes data to a Cassandra data + * source, but Cassandra may need some more time to reach consistency at storage level. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Writer { + + /** + * Creates a writer factory which will be serialized and sent to executors. + */ + DataWriterFactory createWriterFactory(); + + /** + * Commits this writing job with a list of commit messages. The commit messages are collected from + * successful data writers and are produced by {@link DataWriter#commit()}. If this method + * fails(throw exception), this writing job is considered to be failed, and + * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible + * to data source readers if this method succeeds. + * + * Note that, one partition may have multiple committed data writers because of speculative tasks. + * Spark will pick the first successful one and get its commit message. Implementations should be + * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer + * can commit successfully, or have a way to clean up the data of already-committed writers. + */ + void commit(WriterCommitMessage[] messages); + + /** + * Aborts this writing job because some data writers are failed to write the records and aborted, + * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} + * fails. If this method fails(throw exception), the underlying data source may have garbage that + * need to be cleaned manually, but these garbage should not be visible to data source readers. + * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(WriterCommitMessage[] messages); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java new file mode 100644 index 000000000000..14261419af6f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
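One possible (purely illustrative) way to satisfy the commit/abort contract above is a file-staging design: publish only the files named in the commit messages Spark delivers, and treat everything else in the job's staging area as discardable, which also disposes of output from duplicate speculative commits. StagedFile and the private helpers are assumptions of this sketch, not APIs from the patch.

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}

// Hypothetical commit message: the staged file produced by one committed task attempt.
case class StagedFile(path: String) extends WriterCommitMessage

class StagingWriter(jobId: String, targetDir: String) extends DataSourceV2Writer {

  override def createWriterFactory(): DataWriterFactory[Row] = ??? // see the factory sketch below

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // Publish only the files named in the messages Spark picked (one per partition).
    // Files staged by other retried or speculative attempts are dropped with the staging dir.
    messages.foreach { case StagedFile(path) => publish(path, targetDir) }
    deleteStagingDir(jobId)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // Some slots may be null: those writers never committed, or their messages never
    // reached the driver. Revert what we know about, then drop the staging dir; this is
    // the "best effort" clean-up described above.
    messages.filter(_ != null).foreach { case StagedFile(path) => unpublish(path) }
    deleteStagingDir(jobId)
  }

  // Illustrative helpers, e.g. backed by a Hadoop FileSystem as in the test source further down.
  private def publish(stagedFile: String, dir: String): Unit = ???
  private def unpublish(stagedFile: String): Unit = ???
  private def deleteStagingDir(job: String): Unit = ???
}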
+ */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * responsible for writing data for an input RDD partition. + * + * One Spark task has one exclusive data writer, so there is no thread-safe concern. + * + * {@link #write(Object)} is called for each record in the input RDD partition. If one record fails + * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will + * not be processed. If all records are successfully written, {@link #commit()} is called. + * + * If this data writer succeeds(all records are successfully written and {@link #commit()} + * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an + * exception will be sent to the driver side, and Spark will retry this writing task for some times, + * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * + * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task + * takes too long to finish. Different from retried tasks, which are launched one by one after the + * previous one fails, speculative tasks are running simultaneously. It's possible that one input + * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * and data sources should guarantee that these data writers don't conflict and can work together. + * Implementations can coordinate with driver during {@link #commit()} to make sure only one of + * these data writers can commit successfully. Or implementations can allow all of them to commit + * successfully, and have a way to revert committed data writers without the commit message, because + * Spark only accepts the commit message that arrives first and ignore others. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers + * that mix in {@link SupportsWriteInternalRow}. + */ +@InterfaceStability.Evolving +public interface DataWriter { + + /** + * Writes one record. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + void write(T record); + + /** + * Commits this writer after all records are written successfully, returns a commit message which + * will be send back to driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * The written data should only be visible to data source readers after + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * do the final commitment via {@link WriterCommitMessage}. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + WriterCommitMessage commit(); + + /** + * Aborts this writer if it is failed. 
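To make the lifecycle above concrete, here is a deliberately small, hypothetical DataWriter that stages records in a local file. It assumes the two-long-column rows used by the test source further down, and the staged path would normally come from the factory.

import java.io.{File, PrintWriter}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical commit message naming the file this attempt produced.
case class WrittenFile(path: String) extends WriterCommitMessage

class TempFileDataWriter(stagedPath: String) extends DataWriter[Row] {

  // One writer per task attempt, used from a single thread, so plain mutable state is fine.
  private val out = new PrintWriter(new File(stagedPath))

  override def write(record: Row): Unit = {
    // Called once per record; if this throws, Spark calls abort() and stops sending records.
    out.println(s"${record.getLong(0)},${record.getLong(1)}")
  }

  override def commit(): WriterCommitMessage = {
    // The staged file stays invisible to readers until DataSourceV2Writer.commit publishes it.
    out.close()
    WrittenFile(stagedPath)
  }

  override def abort(): Unit = {
    out.close()
    new File(stagedPath).delete()
  }
}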
Implementations should clean up the data for already + * written records. + * + * This method will only be called if there is one record failed to write, or {@link #commit()} + * failed. + * + * If this method fails(throw exception), the underlying data source may have garbage that need + * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but + * these garbage should not be visible to data source readers. + */ + void abort(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java new file mode 100644 index 000000000000..f812d102bda1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * which is responsible for creating and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface DataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task + * failed, Spark launches a new task wth the same task id but different + * attempt number. Or a task is too slow, Spark launches new tasks wth the + * same task id but different attempt number, which means there are multiple + * tasks with the same task id running at the same time. Implementations can + * use this attempt number to distinguish writers of different task attempts. 
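A matching, hypothetical factory that pairs with the TempFileDataWriter sketched earlier: the factory carries only small serializable state, and the partition id and attempt number go into the staged file name so retried or speculative attempts never collide (the _temporary layout mirrors the test source below).

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}

// Created on the driver, serialized, then deserialized on every executor, so it should
// only capture small serializable state such as a base path and the job id.
class TempFileWriterFactory(basePath: String, jobId: String) extends DataWriterFactory[Row] {

  override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = {
    // Both ids go into the staged name, so a retried or speculative attempt of the same
    // partition writes to its own file instead of clobbering an earlier one.
    new TempFileDataWriter(s"$basePath/_temporary/$jobId/$jobId-$partitionId-$attemptNumber")
  }
}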
+ */ + DataWriter createWriter(int partitionId, int attemptNumber); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java new file mode 100644 index 000000000000..a8e95901f3b0 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. + * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get + * changed in the future Spark versions. + */ + +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsWriteInternalRow extends DataSourceV2Writer { + + @Override + default DataWriterFactory createWriterFactory() { + throw new IllegalStateException( + "createWriterFactory should not be called with SupportsWriteInternalRow."); + } + + DataWriterFactory createInternalRowWriterFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java new file mode 100644 index 000000000000..082d6b5dc409 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
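For completeness, a hypothetical writer that opts into the InternalRow fast path by mixing in SupportsWriteInternalRow and supplying only the internal-row factory; the stubbed methods are placeholders.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}

class InternalRowStagingWriter(jobId: String, targetDir: String) extends SupportsWriteInternalRow {

  // Spark calls this instead of createWriterFactory(); the default createWriterFactory()
  // inherited from SupportsWriteInternalRow throws if it is ever invoked.
  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = ???

  override def commit(messages: Array[WriterCommitMessage]): Unit = ???
  override def abort(messages: Array[WriterCommitMessage]): Unit = ???
}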
+ */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side + * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * This is an empty interface, data sources should define their own message class and use it in + * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * implementations. + */ +@InterfaceStability.Evolving +public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d274854..baa510969fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -29,7 +30,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} import org.apache.spark.sql.types.StructType /** @@ -231,12 +234,33 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + cls.newInstance() match { + case ds: WriteSupport => + val options = new DataSourceV2Options(extraOptions.asJava) + // Using a timestamp and a random UUID to distinguish different writing jobs. This is good + // enough as there won't be tons of writing jobs created at the same second. + val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + .format(new Date()) + "-" + UUID.randomUUID() + val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get(), df.logicalPlan) + } + } + + case _ => throw new AnalysisException(s"$cls does not support data writing.") + } + } else { + // Code path for data source v1. 
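+        // Unlike the v2 branch above, which skips execution entirely when createWriter
+        // returns an empty Optional (e.g. SaveMode.Ignore with pre-existing data), the v1
+        // path always goes through DataSource.planForWriting below.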
+ runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index f2cda002245e..df5b524485f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,20 +18,17 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { - // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case DataSourceV2Relation(output, reader) => DataSourceV2ScanExec(output, reader) :: Nil + case WriteToDataSourceV2(writer, query) => + WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala new file mode 100644 index 000000000000..92c1e1f4a338 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * The logical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} + +/** + * The physical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writeTask = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${messages.length} partitions.") + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter), + rdd.partitions.indices, + (index, message: WriterCommitMessage) => messages(index) = message + ) + + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } catch { + case cause: Throwable => + logError(s"Data source writer $writer is aborting.") + try { + writer.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source writer $writer failed to abort.") + cause.addSuppressed(t) + throw new SparkException("Writing job failed.", cause) + } + logError(s"Data source writer $writer aborted.") + throw new SparkException("Writing job aborted.", cause) + } + + sparkContext.emptyRDD + } +} + +object DataWritingSparkTask extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + + // write the data and commit this writer. 
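+    // If writing a record or the commit below throws, the catch block aborts this data
+    // writer before the exception propagates, so the failed attempt leaves no
+    // half-committed writer behind; Spark may then retry the task or abort the whole job.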
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + msg + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } +} + +class RowToInternalRowDataWriterFactory( + rowWriterFactory: DataWriterFactory[Row], + schema: StructType) extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new RowToInternalRowDataWriter( + rowWriterFactory.createWriter(partitionId, attemptNumber), + RowEncoder.apply(schema).resolveAndBind()) + } +} + +class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) + extends DataWriter[InternalRow] { + + override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) + + override def commit(): WriterCommitMessage = rowWriter.commit() + + override def abort(): Unit = rowWriter.abort() +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index f238e565dc2f..092702a1d517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,6 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -80,6 +81,74 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple writable data source") { + // TODO: java implementation. 
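+    // The scenarios below cover every SaveMode, a job that fails partway through writing
+    // (no partial data may remain after the abort), and the SupportsWriteInternalRow path,
+    // all against the file-based SimpleWritableDataSource defined in this package.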
+ Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + // test with different save modes + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("append").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).union(spark.range(10)).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("ignore").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + val e = intercept[Exception] { + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("error").save() + } + assert(e.getMessage.contains("data already exists")) + + // test transaction + val failingUdf = org.apache.spark.sql.functions.udf { + var count = 0 + (id: Long) => { + if (count > 5) { + throw new RuntimeException("testing error") + } + count += 1 + id + } + } + // this input data will fail to read middle way. + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val e2 = intercept[SparkException] { + input.write.format(cls.getName).option("path", path).mode("overwrite").save() + } + assert(e2.getMessage.contains("Writing job aborted")) + // make sure we don't have partial data. + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + // test internal row writer + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).option("internal", "true").mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala new file mode 100644 index 000000000000..6fb60f4d848d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.text.SimpleDateFormat +import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A HDFS based transactional writable data source. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. + */ +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { + + private val schema = new StructType().add("i", "long").add("j", "long") + + class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + override def readSchema(): StructType = schema + + override def createReadTasks(): JList[ReadTask[Row]] = { + val dataPath = new Path(path) + val fs = dataPath.getFileSystem(conf) + if (fs.exists(dataPath)) { + fs.listStatus(dataPath).filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.map { f => + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + }.toList.asJava + } else { + Collections.emptyList() + } + } + } + + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + override def createWriterFactory(): DataWriterFactory[Row] = { + new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + val finalPath = new Path(path) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + try { + for (file <- fs.listStatus(jobPath).map(_.getPath)) { + val dest = new Path(finalPath, file.getName) + if(!fs.rename(file, dest)) { + throw new IOException(s"failed to rename($file, $dest)") + } + } + } finally { + fs.delete(jobPath, true) + } + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + fs.delete(jobPath, true) + } + } + + class InternalRowWriter(jobId: String, path: String, conf: Configuration) + extends Writer(jobId, path, conf) with SupportsWriteInternalRow { + + override def createWriterFactory(): DataWriterFactory[Row] = { + throw new IllegalArgumentException("not expected!") + } + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + new InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + val path = new Path(options.get("path").get()) + val conf = SparkContext.getActive.get.hadoopConfiguration + new Reader(path.toUri.toString, conf) + } + + override def createWriter( + jobId: String, + schema: StructType, + mode: SaveMode, + options: DataSourceV2Options): 
Optional[DataSourceV2Writer] = { + assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) + assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) + + val path = new Path(options.get("path").get()) + val internal = options.get("internal").isPresent + val conf = SparkContext.getActive.get.hadoopConfiguration + val fs = path.getFileSystem(conf) + + if (mode == SaveMode.ErrorIfExists) { + if (fs.exists(path)) { + throw new RuntimeException("data already exists.") + } + } + if (mode == SaveMode.Ignore) { + if (fs.exists(path)) { + return Optional.empty() + } + } + if (mode == SaveMode.Overwrite) { + fs.delete(path, true) + } + + Optional.of(createWriter(jobId, path, conf, internal)) + } + + private def createWriter( + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + val pathStr = path.toUri.toString + if (internal) { + new InternalRowWriter(jobId, pathStr, conf) + } else { + new Writer(jobId, pathStr, conf) + } + } +} + +class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) + extends ReadTask[Row] with DataReader[Row] { + + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ + + override def createReader(): DataReader[Row] = { + val filePath = new Path(path) + val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } +} + +class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[Row] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new SimpleCSVDataWriter(fs, filePath) + } +} + +class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { + + private val out = fs.create(file) + + override def write(record: Row): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} + +class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new InternalRowCSVDataWriter(fs, filePath) + } +} + +class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { + + private val out = fs.create(file) + + override def write(record: InternalRow): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): 
Unit = {
+    try {
+      out.close()
+    } finally {
+      fs.delete(file, false)
+    }
+  }
+}
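For reference, the write path added here is driven entirely through the existing DataFrameWriter API. The snippet below mirrors the suite above and addresses the test-only SimpleWritableDataSource by class name; outside the test module you would substitute your own WriteSupport implementation, and the output path is illustrative.

import org.apache.spark.sql.{SaveMode, SparkSession}

object V2WriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("v2-write-example").getOrCreate()

    // Any DataSourceV2 that implements WriteSupport can be addressed by class name.
    spark.range(10).selectExpr("id AS i", "-id AS j")
      .write
      .format("org.apache.spark.sql.sources.v2.SimpleWritableDataSource")
      .option("path", "/tmp/v2-write-example") // illustrative output location
      .mode(SaveMode.Overwrite)
      .save()

    spark.stop()
  }
}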