From d76e0cd5682bd953e2788f2853ee5003a4a651de Mon Sep 17 00:00:00 2001 From: Louis Chu Date: Sun, 29 Sep 2024 18:50:21 -0700 Subject: [PATCH] [REFACTOR] Move DF reformat from StatementExecutionManagerImpl to QueryResultWriterImpl (#701) * Refactor query result writer Signed-off-by: Louis Chu * Add more scala doc and update sbt Signed-off-by: Louis Chu --------- Signed-off-by: Louis Chu --- build.sbt | 2 +- .../apache/spark/sql/QueryResultWriter.scala | 35 +++++++++++++++- .../org/apache/spark/sql/CommandContext.scala | 1 - .../apache/spark/sql/FlintJobExecutor.scala | 3 +- .../org/apache/spark/sql/FlintREPL.scala | 41 +++++++++++++------ .../org/apache/spark/sql/JobOperator.scala | 1 - .../spark/sql/QueryResultWriterImpl.scala | 30 +++++++++++++- .../sql/StatementExecutionManagerImpl.scala | 15 +++---- .../org/apache/spark/sql/FlintREPLTest.scala | 13 +----- 9 files changed, 101 insertions(+), 40 deletions(-) diff --git a/build.sbt b/build.sbt index 593542e5c..73fb481a6 100644 --- a/build.sbt +++ b/build.sbt @@ -88,7 +88,7 @@ lazy val flintCore = (project in file("flint-core")) exclude ("com.fasterxml.jackson.core", "jackson-databind"), "com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593" exclude("com.fasterxml.jackson.core", "jackson-databind"), - "software.amazon.awssdk" % "auth-crt" % "2.25.23", + "software.amazon.awssdk" % "auth-crt" % "2.28.10" % "provided", "org.scalactic" %% "scalactic" % "3.2.15" % "test", "org.scalatest" %% "scalatest" % "3.2.15" % "test", "org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test", diff --git a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala index 49dc8e355..efb001785 100644 --- a/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala +++ b/flint-commons/src/main/scala/org/apache/spark/sql/QueryResultWriter.scala @@ -13,8 +13,39 @@ import org.opensearch.flint.common.model.FlintStatement trait QueryResultWriter { /** - * Writes the given DataFrame, which represents the result of a query execution, to an external - * data storage based on the provided FlintStatement metadata. + * Writes the given DataFrame to an external data storage based on the FlintStatement metadata. + * This method is responsible for persisting the query results. + * + * Note: This method typically involves I/O operations and may trigger Spark actions to + * materialize the DataFrame if it hasn't been processed yet. + * + * @param dataFrame + * The DataFrame containing the query results to be written. + * @param flintStatement + * The FlintStatement containing metadata that guides the writing process. */ def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit + + /** + * Defines transformations on the given DataFrame and triggers an action to process it. This + * method applies necessary transformations based on the FlintStatement metadata and executes an + * action to compute the result. + * + * Note: Calling this method will trigger the actual data processing in Spark. If the Spark SQL + * thread is waiting for the result of a query, termination on the same thread will be blocked + * until the action completes. + * + * @param dataFrame + * The DataFrame to be processed. + * @param flintStatement + * The FlintStatement containing statement metadata. + * @param queryStartTime + * The start time of the query execution. + * @return + * The processed DataFrame after applying transformations and executing an action. + */ + def processDataFrame( + dataFrame: DataFrame, + flintStatement: FlintStatement, + queryStartTime: Long): DataFrame } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala index 42b1ae2f6..56bd9cb00 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/CommandContext.scala @@ -18,7 +18,6 @@ case class CommandContext( jobType: String, sessionId: String, sessionManager: SessionManager, - queryResultWriter: QueryResultWriter, queryExecutionTimeout: Duration, inactivityLimitMillis: Long, queryWaitTimeMillis: Long, diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index 24d68fd47..c076f9974 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.commons.text.StringEscapeUtils.unescapeJava +import org.opensearch.common.Strings import org.opensearch.flint.core.IRestHighLevelClient import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage} import org.opensearch.flint.core.metrics.MetricConstants @@ -533,7 +534,7 @@ trait FlintJobExecutor { } def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = { - if (className.isEmpty) { + if (Strings.isNullOrEmpty(className)) { defaultConstructor } else { try { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index a0516a37a..cdeebe663 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor { return } - val queryResultWriter = - instantiateQueryResultWriter(conf, sessionManager.getSessionContext) val commandContext = CommandContext( applicationId, jobId, @@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor { jobType, sessionId, sessionManager, - queryResultWriter, queryExecutionTimeoutSecs, inactivityLimitMillis, queryWaitTimeoutMillis, @@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor { // 1 thread for async query execution val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1) implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - + val queryResultWriter = instantiateQueryResultWriter(spark, commandContext) var futurePrepareQueryExecution: Future[Either[String, Unit]] = null try { logInfo(s"""Executing session with sessionId: ${sessionId}""") @@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor { executionContext, lastCanPickCheckTime) val result: (Long, VerificationResult, Boolean, Long) = - processCommands(statementsExecutionManager, commandContext, commandState) + processCommands( + statementsExecutionManager, + queryResultWriter, + commandContext, + commandState) val ( updatedLastActivityTime, @@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processCommands( statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, context: CommandContext, state: CommandState): (Long, VerificationResult, Boolean, Long) = { import context._ @@ -525,6 +527,7 @@ object FlintREPL extends Logging with FlintJobExecutor { val (dataToWrite, returnedVerificationResult) = processStatementOnVerification( statementExecutionManager, + queryResultWriter, flintStatement, state, context) @@ -532,7 +535,7 @@ object FlintREPL extends Logging with FlintJobExecutor { verificationResult = returnedVerificationResult finalizeCommand( statementExecutionManager, - context, + queryResultWriter, dataToWrite, flintStatement, statementTimerContext) @@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor { */ private def finalizeCommand( statementExecutionManager: StatementExecutionManager, - commandContext: CommandContext, + queryResultWriter: QueryResultWriter, dataToWrite: Option[DataFrame], flintStatement: FlintStatement, statementTimerContext: Timer.Context): Unit = { - import commandContext._ try { dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement)) if (flintStatement.isRunning || flintStatement.isWaiting) { @@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, flintStatement: FlintStatement, statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor { private def processStatementOnVerification( statementExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, flintStatement: FlintStatement, commandState: CommandState, commandContext: CommandContext) = { @@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor { spark: SparkSession, flintStatement: FlintStatement, statementsExecutionManager: StatementExecutionManager, + queryResultWriter: QueryResultWriter, dataSource: String, sessionId: String, executionContext: ExecutionContextExecutor, @@ -801,7 +809,14 @@ object FlintREPL extends Logging with FlintJobExecutor { startTime) } else { val futureQueryExecution = Future { - statementsExecutionManager.executeStatement(flintStatement) + val startTime = System.currentTimeMillis() + // Execute the statement and get the resulting DataFrame + // This step may involve Spark transformations, but not necessarily actions + val df = statementsExecutionManager.executeStatement(flintStatement) + // Process the DataFrame, applying any necessary transformations + // and triggering Spark actions to materialize the results + // This is where the actual data processing occurs + queryResultWriter.processDataFrame(df, flintStatement, startTime) }(executionContext) // time out after 10 minutes ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut) @@ -998,11 +1013,11 @@ object FlintREPL extends Logging with FlintJobExecutor { } private def instantiateQueryResultWriter( - sparkConf: SparkConf, - context: Map[String, Any]): QueryResultWriter = { + spark: SparkSession, + commandContext: CommandContext): QueryResultWriter = { instantiate( - new QueryResultWriterImpl(context), - sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + new QueryResultWriterImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) } private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index deee6eb1d..58d868a2e 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -61,7 +61,6 @@ case class JobOperator( jobType, "", // FlintJob doesn't have sessionId null, // FlintJob doesn't have SessionManager - null, // FlintJob doesn't have QueryResultWriter Duration.Inf, // FlintJob doesn't have queryExecutionTimeout -1, // FlintJob doesn't have inactivityLimitMillis -1, // FlintJob doesn't have queryWaitTimeMillis diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala index 23d7f42a1..61c6e0747 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/QueryResultWriterImpl.scala @@ -10,9 +10,14 @@ import org.opensearch.flint.common.model.FlintStatement import org.apache.spark.internal.Logging import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.sql.util.CleanerFactory -class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging { +class QueryResultWriterImpl(commandContext: CommandContext) + extends QueryResultWriter + with FlintJobExecutor + with Logging { + private val context = commandContext.sessionManager.getSessionContext private val resultIndex = context("resultIndex").asInstanceOf[String] // Initialize OSClient with Flint options because custom session manager implementation should not have it in the context private val osClient = new OSClient(FlintSparkConf().flintOptions()) @@ -20,4 +25,27 @@ class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = { writeDataFrameToOpensearch(dataFrame, resultIndex, osClient) } + + override def processDataFrame( + dataFrame: DataFrame, + statement: FlintStatement, + queryStartTime: Long): DataFrame = { + import commandContext._ + + /** + * Reformat the given DataFrame to the desired format for OpenSearch storage. + */ + getFormattedData( + applicationId, + jobId, + dataFrame, + spark, + dataSource, + statement.queryId, + statement.query, + sessionId, + queryStartTime, + currentTimeProvider, + CleanerFactory.cleaner(false)) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala index 4e9435f7b..432d6df11 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/StatementExecutionManagerImpl.scala @@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext) } override def executeStatement(statement: FlintStatement): DataFrame = { - import commandContext._ - executeQuery( - applicationId, - jobId, - spark, - statement.query, - dataSource, + import commandContext.spark + // we have to set job group in the same thread that started the query according to spark doc + spark.sparkContext.setJobGroup( statement.queryId, - sessionId, - false) + "Job group for " + statement.queryId, + interruptOnCancel = true) + spark.sql(statement.query) } private def createOpenSearchQueryReader() = { diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala index 355bd9ede..5eeccce73 100644 --- a/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala +++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/FlintREPLTest.scala @@ -675,7 +675,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -748,7 +747,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -761,6 +759,7 @@ class FlintREPLTest mockSparkSession, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -809,7 +808,6 @@ class FlintREPLTest when(mockSparkSession.sparkContext).thenReturn(sparkContext) // Assume handleQueryException logs the error and returns an error message string - val mockErrorString = "Error due to syntax" when(mockSparkSession.createDataFrame(any[Seq[Product]])(any[TypeTag[Product]])) .thenReturn(expectedDataFrame) when(expectedDataFrame.toDF(any[Seq[String]]: _*)).thenReturn(expectedDataFrame) @@ -824,7 +822,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), 60, 60, @@ -837,6 +834,7 @@ class FlintREPLTest mockSparkSession, flintStatement, statementExecutionManager, + queryResultWriter, dataSource, sessionId, executionContext, @@ -1076,7 +1074,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), shortInactivityLimit, 60, @@ -1146,7 +1143,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), longInactivityLimit, 60, @@ -1212,7 +1208,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1283,7 +1278,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1367,7 +1361,6 @@ class FlintREPLTest override val osClient: OSClient = mockOSClient override lazy val flintSessionIndexUpdater: OpenSearchUpdater = mockOpenSearchUpdater } - val queryResultWriter = mock[QueryResultWriter] val commandContext = CommandContext( applicationId, @@ -1377,7 +1370,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60, @@ -1453,7 +1445,6 @@ class FlintREPLTest INTERACTIVE_JOB_TYPE, sessionId, sessionManager, - queryResultWriter, Duration(10, MINUTES), inactivityLimit, 60,