diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index dea1e017b2e5..70e7cd9a1e40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -42,16 +42,14 @@ object CommandUtils extends Logging {
 
   /** Change statistics after changing data by commands. */
   def updateTableStats(sparkSession: SparkSession, table: CatalogTable): Unit = {
-    if (table.stats.nonEmpty) {
-      val catalog = sparkSession.sessionState.catalog
-      if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) {
-        val newTable = catalog.getTableMetadata(table.identifier)
-        val newSize = CommandUtils.calculateTotalSize(sparkSession, newTable)
-        val newStats = CatalogStatistics(sizeInBytes = newSize)
-        catalog.alterTableStats(table.identifier, Some(newStats))
-      } else {
-        catalog.alterTableStats(table.identifier, None)
-      }
+    val catalog = sparkSession.sessionState.catalog
+    if (sparkSession.sessionState.conf.autoSizeUpdateEnabled) {
+      val newTable = catalog.getTableMetadata(table.identifier)
+      val newSize = CommandUtils.calculateTotalSize(sparkSession, newTable)
+      val newStats = CatalogStatistics(sizeInBytes = newSize)
+      catalog.alterTableStats(table.identifier, Some(newStats))
+    } else if (table.stats.nonEmpty) {
+      catalog.alterTableStats(table.identifier, None)
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index d071efb804ad..c4ba12289815 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -337,6 +337,26 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
     }
   }
 
+  test("auto gather stats after insert command") {
+    val table = "change_stats_insert_datasource_table"
+    Seq(false, true).foreach { autoUpdate =>
+      withSQLConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED.key -> autoUpdate.toString) {
+        withTable(table) {
+          sql(s"CREATE TABLE $table (i int, j string) USING PARQUET")
+          // insert into command
+          sql(s"INSERT INTO TABLE $table SELECT 1, 'abc'")
+          val stats = getCatalogTable(table).stats
+          if (autoUpdate) {
+            assert(stats.isDefined)
+            assert(stats.get.sizeInBytes >= 0)
+          } else {
+            assert(stats.isEmpty)
+          }
+        }
+      }
+    }
+  }
+
   test("invalidation of tableRelationCache after inserts") {
     val table = "invalidate_catalog_cache_table"
     Seq(false, true).foreach { autoUpdate =>