diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
index 31575138f8..a03a150568 100644
--- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala
@@ -132,7 +132,7 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
         .newBuilder()
         .setOutputPath(outputPath)
         .setCompression(codec)
-        .addAllColumnNames(cmd.query.output.map(_.name).asJava)
+        .addAllColumnNames(cmd.outputColumnNames.asJava)
 
     // Note: work_dir, job_id, and task_attempt_id will be set at execution time
     // in CometNativeWriteExec, as they depend on the Spark task context
@@ -201,7 +201,7 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec
       throw new SparkException(s"Could not instantiate FileCommitProtocol: ${e.getMessage}")
     }
 
-    CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId)
+    CometNativeWriteExec(nativeOp, childPlan, outputPath, committer, jobId, cmd.catalogTable)
   }
 
   private def parseCompressionCodec(cmd: InsertIntoHadoopFsRelationCommand) = {
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
index 39e7ac6eef..0cb3741f45 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -63,7 +64,8 @@ case class CometNativeWriteExec(
     child: SparkPlan,
     outputPath: String,
     committer: Option[FileCommitProtocol] = None,
-    jobTrackerID: String = Utils.createTempDir().getName)
+    jobTrackerID: String = Utils.createTempDir().getName,
+    catalogTable: Option[CatalogTable] = None)
     extends CometNativeExec
     with UnaryExecNode {
 
@@ -135,6 +137,11 @@ case class CometNativeWriteExec(
       }
     }
 
+    // Refresh the catalog table cache so subsequent reads see the new data
+    catalogTable.foreach { ct =>
+      session.catalog.refreshTable(ct.identifier.quotedString)
+    }
+
     // Return empty RDD as write operations don't return data
     sparkContext.emptyRDD[InternalRow]
   }
diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
index e4c405c003..9e4c96854f 100644
--- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala
@@ -377,6 +377,167 @@ class CometParquetWriterSuite extends CometTestBase {
     }
   }
 
+  private def withNativeWriteConf(f: => Unit): Unit = {
+    withSQLConf(
+      CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+      CometConf.COMET_EXEC_ENABLED.key -> "true",
+      CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true") {
+      f
+    }
+  }
+
+  private def assertCometNativeWrite(insertSql: String): Unit = {
+    val plan = captureSqlWritePlan(insertSql)
+    val hasNativeWrite = plan.exists {
+      case _: CometNativeWriteExec => true
+      case d: DataWritingCommandExec =>
+        d.child.exists(_.isInstanceOf[CometNativeWriteExec])
+      case _ => false
+    }
+    assert(
+      hasNativeWrite,
+      s"Expected CometNativeWriteExec in plan, but not found:\n${plan.treeString}")
+  }
+
+  private def captureSqlWritePlan(sqlText: String): SparkPlan = {
+    var capturedPlan: Option[QueryExecution] = None
+
+    val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        if (funcName == "command") {
+          capturedPlan = Some(qe)
+        }
+      }
+      override def onFailure(
+          funcName: String,
+          qe: QueryExecution,
+          exception: Exception): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+    try {
+      sql(sqlText)
+      val maxWaitTimeMs = 5000
+      val checkIntervalMs = 50
+      var iterations = 0
+      while (capturedPlan.isEmpty && iterations < maxWaitTimeMs / checkIntervalMs) {
+        Thread.sleep(checkIntervalMs)
+        iterations += 1
+      }
+      assert(capturedPlan.isDefined, s"Failed to capture plan for: $sqlText")
+      stripAQEPlan(capturedPlan.get.executedPlan)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
+
+  // SPARK-38811 INSERT INTO on columns added with ALTER TABLE ADD COLUMNS: Positive tests
+  // Mirrors the Spark InsertSuite test to validate Comet native writer compatibility.
+
+  test("SPARK-38811: simple default value with concat expression") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i boolean) using parquet")
+        sql("alter table t add column s string default concat('abc', 'def')")
+        assertCometNativeWrite("insert into t values(true, default)")
+        checkAnswer(spark.table("t"), Row(true, "abcdef"))
+      }
+    }
+  }
+
+  test("SPARK-38811: multiple trailing default values") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i int) using parquet")
+        sql("alter table t add column s bigint default 42")
+        sql("alter table t add column x bigint default 43")
+        assertCometNativeWrite("insert into t(i) values(1)")
+        checkAnswer(spark.table("t"), Row(1, 42, 43))
+      }
+    }
+  }
+
+  test("SPARK-38811: multiple trailing defaults via add columns") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i int) using parquet")
+        sql("alter table t add columns s bigint default 42, x bigint default 43")
+        assertCometNativeWrite("insert into t(i) values(1)")
+        checkAnswer(spark.table("t"), Row(1, 42, 43))
+      }
+    }
+  }
+
+  test("SPARK-38811: default with nullable column (no default)") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i int) using parquet")
+        sql("alter table t add column s bigint default 42")
+        sql("alter table t add column x bigint")
+        assertCometNativeWrite("insert into t(i) values(1)")
+        checkAnswer(spark.table("t"), Row(1, 42, null))
+      }
+    }
+  }
+
+  test("SPARK-38811: expression default (41 + 1)") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i boolean) using parquet")
+        sql("alter table t add column s bigint default 41 + 1")
+        assertCometNativeWrite("insert into t(i) values(default)")
+        checkAnswer(spark.table("t"), Row(null, 42))
+      }
+    }
+  }
+
+  test("SPARK-38811: explicit defaults in multiple positions") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i boolean default false) using parquet")
+        sql("alter table t add column s bigint default 42")
+        assertCometNativeWrite("insert into t values(false, default), (default, 42)")
+        checkAnswer(spark.table("t"), Seq(Row(false, 42), Row(false, 42)))
+      }
+    }
+  }
+
+  test("SPARK-38811: default with alias over VALUES") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i boolean) using parquet")
+        sql("alter table t add column s bigint default 42")
+        assertCometNativeWrite(
+          "insert into t select * from values (false, default) as tab(col, other)")
+        checkAnswer(spark.table("t"), Row(false, 42))
+      }
+    }
+  }
+
+  test("SPARK-38811: default value in wrong order evaluates to NULL") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i boolean) using parquet")
+        sql("alter table t add column s bigint default 42")
+        assertCometNativeWrite("insert into t values (default, 43)")
+        checkAnswer(spark.table("t"), Row(null, 43))
+      }
+    }
+  }
+
+  // INSERT INTO ... SELECT with native writer fails,
+  // open issue: https://github.com/apache/datafusion-comet/issues/3521
+  ignore("SPARK-38811: default via SELECT statement") {
+    withNativeWriteConf {
+      withTable("t") {
+        sql("create table t(i boolean) using parquet")
+        sql("alter table t add column s bigint default 42")
+        sql("insert into t select false, default")
+        checkAnswer(spark.table("t"), Row(false, 42))
+      }
+    }
+  }
+
   private def createTestData(inputDir: File): String = {
     val inputPath = new File(inputDir, "input.parquet").getAbsolutePath
     val schema = FuzzDataGenerator.generateSchema(