diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 43b95766882f..87f20e1d780d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -1149,7 +1149,9 @@ object StateStore extends Logging { threadPool.shutdownNow() // Cancel currently executing tasks // Wait a while for tasks to respond to being cancelled - if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) { + // To avoid long test times, use minimum of timeout value or 60 seconds + if (!threadPool.awaitTermination(Math.min(60, shutdownTimeout), + TimeUnit.SECONDS)) { logError("MaintenanceThreadPool did not terminate") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 6948aedd5640..c72d88709d25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -294,11 +294,12 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { query.processAllAvailable() inputData2.addData(1, 2, 3) query2.processAllAvailable() - Thread.sleep(500) } - // Verify only the partitions in badPartitions doesn't have a snapshot - verifySnapshotUploadEvents(coordRef, query, badPartitions) - verifySnapshotUploadEvents(coordRef, query2, badPartitions) + eventually(timeout(5.seconds)) { + // Verify only the partitions in badPartitions doesn't have a snapshot + verifySnapshotUploadEvents(coordRef, query, badPartitions) + verifySnapshotUploadEvents(coordRef, query2, badPartitions) + } 
def verifyShouldForceSnapshotOnBadPartitions( checkpointDir: String, @@ -335,15 +336,16 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // commit should automatically trigger snapshot inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(500) // Verify that snapshot was created and shouldForceSnapshotUpload is now false - verifyShouldForceSnapshotOnBadPartitions( - stateCheckpointDir, - query.runId, - shouldForce = false, - expectedSnapshotVersion = Some(snapshotVersionOnLagDetected + 1) - ) + eventually(timeout(5.seconds)) { + verifyShouldForceSnapshotOnBadPartitions( + stateCheckpointDir, + query.runId, + shouldForce = false, + expectedSnapshotVersion = Some(snapshotVersionOnLagDetected + 1) + ) + } val streamingQuery2 = query2.asInstanceOf[StreamingQueryWrapper].streamingQuery val stateCheckpointDir2 = streamingQuery2.lastExecution.checkpointLocation @@ -392,19 +394,20 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until snapshotVersionOnLagDetected).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(500) } // Verify only the partitions in badPartitions doesn't have a snapshot - val laggingStores = coordRef.getLaggingStoresForTesting( - query.runId, - snapshotVersionOnLagDetected + 1) - assert(laggingStores.size == badPartitions.size) - assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions) + eventually(timeout(5.seconds)) { + val laggingStores = coordRef.getLaggingStoresForTesting( + query.runId, + snapshotVersionOnLagDetected + 1) + assert(laggingStores.size == badPartitions.size) + assert(laggingStores.map(_.storeId.partitionId).toSet == badPartitions) + } // The coordinator should detected that lagging stores now. 
So next // commit should automatically trigger snapshot inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(500) + val streamingQuery = query.asInstanceOf[StreamingQueryWrapper].streamingQuery val stateCheckpointDir = streamingQuery.lastExecution.checkpointLocation val storeId = StateStoreId( @@ -414,12 +417,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { StateStoreId.DEFAULT_STORE_NAME ) val providerId = StateStoreProviderId(storeId, query.runId) - val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) - assert(latestSnapshotVersion.get == snapshotVersionOnLagDetected + 1) - val laggingStores2 = coordRef.getLaggingStoresForTesting( - query.runId, - snapshotVersionOnLagDetected + 1) - assert(laggingStores2.isEmpty) + eventually(timeout(5.seconds)) { + val latestSnapshotVersion = coordRef.getLatestSnapshotVersionForTesting(providerId) + assert(latestSnapshotVersion.get == snapshotVersionOnLagDetected + 1) + val laggingStores2 = coordRef.getLaggingStoresForTesting( + query.runId, + snapshotVersionOnLagDetected + 1) + assert(laggingStores2.isEmpty) + } query.stop() } } @@ -430,7 +435,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { withCoordinatorAndSQLConf( sc, - SQLConf.SHUFFLE_PARTITIONS.key -> "5", + SQLConf.SHUFFLE_PARTITIONS.key -> "3", SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100", SQLConf.STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT.key -> "3", SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1", @@ -446,13 +451,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { val inputData = MemoryStream[Int] val query = setUpStatefulQuery(inputData, "query") // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 6).foreach { _ => + (0 until 4).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(500) } // Verify only the partitions in 
badPartitions are marked as lagging - verifySnapshotUploadEvents(coordRef, query, badPartitions) + eventually(timeout(5.seconds)) { + verifySnapshotUploadEvents(coordRef, query, badPartitions) + } query.stop() } } @@ -492,14 +498,15 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { .option("checkpointLocation", checkpointLocation.toString) .start() // Add, commit, and wait multiple times to force snapshot versions and time difference - (0 until 7).foreach { _ => + (0 until 4).foreach { _ => input1.addData(1, 5) input2.addData(1, 5, 10) query.processAllAvailable() - Thread.sleep(500) } // Verify only the partitions in badPartitions are marked as lagging - verifySnapshotUploadEvents(coordRef, query, badPartitions, allJoinStateStoreNames) + eventually(timeout(5.seconds)) { + verifySnapshotUploadEvents(coordRef, query, badPartitions, allJoinStateStoreNames) + } query.stop() } } @@ -537,35 +544,40 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { // Process twice the amount of data for the first query input1.addData(1, 2, 3) query1.processAllAvailable() - Thread.sleep(1000) } // Verify that the coordinator logged the correct lagging stores for the first query val streamingQuery1 = query1.asInstanceOf[StreamingQueryWrapper].streamingQuery - val latestVersion1 = streamingQuery1.lastProgress.batchId + 1 - val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, latestVersion1) - assert(laggingStores1.size == 2) - assert(laggingStores1.forall(_.storeId.partitionId <= 1)) - assert(laggingStores1.forall(_.queryRunId == query1.runId)) + eventually(timeout(5.seconds)) { + val latestVersion1 = streamingQuery1.lastProgress.batchId + 1 + val laggingStores1 = coordRef.getLaggingStoresForTesting(query1.runId, latestVersion1) + + assert(laggingStores1.size == 2) + assert(laggingStores1.forall(_.storeId.partitionId <= 1)) + assert(laggingStores1.forall(_.queryRunId == query1.runId)) + } // Verify that 
the second query run hasn't reported anything yet due to lack of data val streamingQuery2 = query2.asInstanceOf[StreamingQueryWrapper].streamingQuery - var latestVersion2 = streamingQuery2.lastProgress.batchId + 1 - var laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) - assert(laggingStores2.isEmpty) + eventually(timeout(5.seconds)) { + val latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + val laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + assert(laggingStores2.isEmpty) + } // Process some more data for the second query to force lag reports input2.addData(1, 2, 3) query2.processAllAvailable() - Thread.sleep(500) // Verify that the coordinator logged the correct lagging stores for the second query - latestVersion2 = streamingQuery2.lastProgress.batchId + 1 - laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) + eventually(timeout(5.seconds)) { + val latestVersion2 = streamingQuery2.lastProgress.batchId + 1 + val laggingStores2 = coordRef.getLaggingStoresForTesting(query2.runId, latestVersion2) - assert(laggingStores2.size == 2) - assert(laggingStores2.forall(_.storeId.partitionId <= 1)) - assert(laggingStores2.forall(_.queryRunId == query2.runId)) + assert(laggingStores2.size == 2) + assert(laggingStores2.forall(_.storeId.partitionId <= 1)) + assert(laggingStores2.forall(_.queryRunId == query2.runId)) + } } } @@ -600,16 +612,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { inputData.addData(1, 2, 3) query.processAllAvailable() - // Sleep for the duration of a maintenance interval - which should be enough - // to pass the time check for lagging stores. - Thread.sleep(100) - - val latestVersion = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 - // Verify that no instances are marked as lagging, even when upload messages are sent.
- // Since snapshot uploads are tied to commit, the lack of version difference should prevent - // the stores from being marked as lagging. - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + eventually(timeout(5.seconds)) { + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 + // Verify that no instances are marked as lagging, even when upload messages are sent. + // Since snapshot uploads are tied to commit, the lack of version difference should + // prevent the stores from being marked as lagging. + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).isEmpty) + } query.stop() } } @@ -641,12 +651,14 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { (0 until 3).foreach { _ => inputData.addData(1, 2, 3) query.processAllAvailable() - Thread.sleep(500) } - val latestVersion = - query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 - // Verify that all instances are marked as lagging, since no upload messages are being sent - assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).size == 2) + eventually(timeout(5.seconds)) { + val latestVersion = + query.asInstanceOf[StreamingQueryWrapper].streamingQuery.lastProgress.batchId + 1 + // Verify that all instances are marked as lagging, since no upload messages are being + // sent + assert(coordRef.getLaggingStoresForTesting(query.runId, latestVersion).size == 2) + } query.stop() } }