Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Move DF reformat from StatementExecutionManagerImpl to QueryResultWriterImpl #701

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ lazy val flintCore = (project in file("flint-core"))
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
"com.amazonaws" % "aws-java-sdk-cloudwatch" % "1.12.593"
exclude("com.fasterxml.jackson.core", "jackson-databind"),
"software.amazon.awssdk" % "auth-crt" % "2.25.23",
"software.amazon.awssdk" % "auth-crt" % "2.28.10" % "provided",
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,39 @@ import org.opensearch.flint.common.model.FlintStatement
trait QueryResultWriter {

/**
* Writes the given DataFrame, which represents the result of a query execution, to an external
* data storage based on the provided FlintStatement metadata.
* Writes the given DataFrame to an external data storage based on the FlintStatement metadata.
* This method is responsible for persisting the query results.
*
* Note: This method typically involves I/O operations and may trigger Spark actions to
* materialize the DataFrame if it hasn't been processed yet.
*
* @param dataFrame
* The DataFrame containing the query results to be written.
* @param flintStatement
* The FlintStatement containing metadata that guides the writing process.
*/
def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit

/**
* Defines transformations on the given DataFrame and triggers an action to process it. This
* method applies necessary transformations based on the FlintStatement metadata and executes an
* action to compute the result.
*
* Note: Calling this method will trigger the actual data processing in Spark. If the Spark SQL
* thread is waiting for the result of a query, termination on the same thread will be blocked
* until the action completes.
*
* @param dataFrame
* The DataFrame to be processed.
* @param flintStatement
* The FlintStatement containing statement metadata.
* @param queryStartTime
* The start time of the query execution.
* @return
* The processed DataFrame after applying transformations and executing an action.
*/
def processDataFrame(
dataFrame: DataFrame,
flintStatement: FlintStatement,
queryStartTime: Long): DataFrame
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ case class CommandContext(
jobType: String,
sessionId: String,
sessionManager: SessionManager,
queryResultWriter: QueryResultWriter,
queryExecutionTimeout: Duration,
inactivityLimitMillis: Long,
queryWaitTimeMillis: Long,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.amazonaws.services.s3.model.AmazonS3Exception
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.commons.text.StringEscapeUtils.unescapeJava
import org.opensearch.common.Strings
import org.opensearch.flint.core.IRestHighLevelClient
import org.opensearch.flint.core.logging.{CustomLogging, ExceptionMessages, OperationMessage}
import org.opensearch.flint.core.metrics.MetricConstants
Expand Down Expand Up @@ -533,7 +534,7 @@ trait FlintJobExecutor {
}

def instantiate[T](defaultConstructor: => T, className: String, args: Any*): T = {
if (className.isEmpty) {
if (Strings.isNullOrEmpty(className)) {
defaultConstructor
} else {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
return
}

val queryResultWriter =
instantiateQueryResultWriter(conf, sessionManager.getSessionContext)
val commandContext = CommandContext(
applicationId,
jobId,
Expand All @@ -179,7 +177,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
jobType,
sessionId,
sessionManager,
queryResultWriter,
queryExecutionTimeoutSecs,
inactivityLimitMillis,
queryWaitTimeoutMillis,
Expand Down Expand Up @@ -316,7 +313,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
// 1 thread for async query execution
val threadPool = threadPoolFactory.newDaemonThreadPoolScheduledExecutor("flint-repl-query", 1)
implicit val executionContext = ExecutionContext.fromExecutor(threadPool)

val queryResultWriter = instantiateQueryResultWriter(spark, commandContext)
var futurePrepareQueryExecution: Future[Either[String, Unit]] = null
try {
logInfo(s"""Executing session with sessionId: ${sessionId}""")
Expand All @@ -342,7 +339,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
executionContext,
lastCanPickCheckTime)
val result: (Long, VerificationResult, Boolean, Long) =
processCommands(statementsExecutionManager, commandContext, commandState)
processCommands(
statementsExecutionManager,
queryResultWriter,
commandContext,
commandState)

val (
updatedLastActivityTime,
Expand Down Expand Up @@ -491,6 +492,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processCommands(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
context: CommandContext,
state: CommandState): (Long, VerificationResult, Boolean, Long) = {
import context._
Expand Down Expand Up @@ -525,14 +527,15 @@ object FlintREPL extends Logging with FlintJobExecutor {
val (dataToWrite, returnedVerificationResult) =
processStatementOnVerification(
statementExecutionManager,
queryResultWriter,
flintStatement,
state,
context)

verificationResult = returnedVerificationResult
finalizeCommand(
statementExecutionManager,
context,
queryResultWriter,
dataToWrite,
flintStatement,
statementTimerContext)
Expand All @@ -558,11 +561,10 @@ object FlintREPL extends Logging with FlintJobExecutor {
*/
private def finalizeCommand(
statementExecutionManager: StatementExecutionManager,
commandContext: CommandContext,
queryResultWriter: QueryResultWriter,
dataToWrite: Option[DataFrame],
flintStatement: FlintStatement,
statementTimerContext: Timer.Context): Unit = {
import commandContext._
try {
dataToWrite.foreach(df => queryResultWriter.writeDataFrame(df, flintStatement))
if (flintStatement.isRunning || flintStatement.isWaiting) {
Expand Down Expand Up @@ -626,6 +628,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
Expand All @@ -640,6 +643,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -677,6 +681,7 @@ object FlintREPL extends Logging with FlintJobExecutor {

private def processStatementOnVerification(
statementExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
flintStatement: FlintStatement,
commandState: CommandState,
commandContext: CommandContext) = {
Expand All @@ -698,6 +703,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand Down Expand Up @@ -764,6 +770,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark,
flintStatement,
statementExecutionManager,
queryResultWriter,
dataSource,
sessionId,
executionContext,
Expand All @@ -782,6 +789,7 @@ object FlintREPL extends Logging with FlintJobExecutor {
spark: SparkSession,
flintStatement: FlintStatement,
statementsExecutionManager: StatementExecutionManager,
queryResultWriter: QueryResultWriter,
dataSource: String,
sessionId: String,
executionContext: ExecutionContextExecutor,
Expand All @@ -801,7 +809,14 @@ object FlintREPL extends Logging with FlintJobExecutor {
startTime)
} else {
val futureQueryExecution = Future {
statementsExecutionManager.executeStatement(flintStatement)
val startTime = System.currentTimeMillis()
// Execute the statement and get the resulting DataFrame
// This step may involve Spark transformations, but not necessarily actions
val df = statementsExecutionManager.executeStatement(flintStatement)
// Process the DataFrame, applying any necessary transformations
// and triggering Spark actions to materialize the results
// This is where the actual data processing occurs
queryResultWriter.processDataFrame(df, flintStatement, startTime)
}(executionContext)
// time out after 10 minutes
ThreadUtils.awaitResult(futureQueryExecution, queryExecutionTimeOut)
Expand Down Expand Up @@ -998,11 +1013,11 @@ object FlintREPL extends Logging with FlintJobExecutor {
}

private def instantiateQueryResultWriter(
sparkConf: SparkConf,
context: Map[String, Any]): QueryResultWriter = {
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(context),
sparkConf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ case class JobOperator(
jobType,
"", // FlintJob doesn't have sessionId
null, // FlintJob doesn't have SessionManager
null, // FlintJob doesn't have QueryResultWriter
Duration.Inf, // FlintJob doesn't have queryExecutionTimeout
-1, // FlintJob doesn't have inactivityLimitMillis
-1, // FlintJob doesn't have queryWaitTimeMillis
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,42 @@ import org.opensearch.flint.common.model.FlintStatement
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintJob.writeDataFrameToOpensearch
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.util.CleanerFactory

class QueryResultWriterImpl(context: Map[String, Any]) extends QueryResultWriter with Logging {
class QueryResultWriterImpl(commandContext: CommandContext)
extends QueryResultWriter
with FlintJobExecutor
with Logging {

private val context = commandContext.sessionManager.getSessionContext
private val resultIndex = context("resultIndex").asInstanceOf[String]
// Initialize OSClient with Flint options because custom session manager implementation should not have it in the context
private val osClient = new OSClient(FlintSparkConf().flintOptions())

override def writeDataFrame(dataFrame: DataFrame, flintStatement: FlintStatement): Unit = {
writeDataFrameToOpensearch(dataFrame, resultIndex, osClient)
}

override def processDataFrame(
dataFrame: DataFrame,
statement: FlintStatement,
queryStartTime: Long): DataFrame = {
import commandContext._

/**
* Reformat the given DataFrame to the desired format for OpenSearch storage.
*/
getFormattedData(
applicationId,
jobId,
dataFrame,
spark,
dataSource,
statement.queryId,
statement.query,
sessionId,
queryStartTime,
currentTimeProvider,
CleanerFactory.cleaner(false))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,13 @@ class StatementExecutionManagerImpl(commandContext: CommandContext)
}

override def executeStatement(statement: FlintStatement): DataFrame = {
import commandContext._
executeQuery(
applicationId,
jobId,
spark,
statement.query,
dataSource,
import commandContext.spark
// we have to set job group in the same thread that started the query according to spark doc
spark.sparkContext.setJobGroup(
statement.queryId,
sessionId,
false)
"Job group for " + statement.queryId,
interruptOnCancel = true)
spark.sql(statement.query)
}

private def createOpenSearchQueryReader() = {
Expand Down
Loading
Loading