From 23315a02f82b71df9db2a5667939ab4c7c65bd24 Mon Sep 17 00:00:00 2001
From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com>
Date: Wed, 9 Oct 2024 17:09:21 +0400
Subject: [PATCH] chore(cubesql): Do not call async Node functions while
 planning

Compiler cache lookups used to happen inside planning and rewriting,
where each lookup could call back into async Node functions (the
transport's compiler_id/meta). Resolve the CompilerCacheEntry once
before planning starts and pass the Arc through the rewriter and
compiler cache APIs instead.

---
 .../cubesql/src/compile/query_engine.rs       |  76 +++++++++-
 .../cubesql/src/compile/rewrite/rewriter.rs   |  17 +--
 rust/cubesql/cubesql/src/compile/router.rs    |  48 +-----
 .../cubesql/cubesql/src/sql/compiler_cache.rs | 142 +++++++++---------
 rust/cubesql/cubesql/src/sql/postgres/shim.rs |  39 ++---
 5 files changed, 172 insertions(+), 150 deletions(-)

diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs
index 2532fa12efc48..e4af6c29ff94f 100644
--- a/rust/cubesql/cubesql/src/compile/query_engine.rs
+++ b/rust/cubesql/cubesql/src/compile/query_engine.rs
@@ -1,5 +1,8 @@
 use crate::compile::engine::df::planner::CubeQueryPlanner;
-use std::{backtrace::Backtrace, collections::HashMap, future::Future, pin::Pin, sync::Arc};
+use std::{
+    backtrace::Backtrace, collections::HashMap, future::Future, pin::Pin, sync::Arc,
+    time::SystemTime,
+};
 
 use crate::{
     compile::{
@@ -21,8 +24,9 @@ use crate::{
     },
     config::ConfigObj,
     sql::{
-        compiler_cache::CompilerCache, statement::SensitiveDataSanitizer, SessionManager,
-        SessionState,
+        compiler_cache::{CompilerCache, CompilerCacheEntry},
+        statement::SensitiveDataSanitizer,
+        SessionManager, SessionState,
     },
     transport::{LoadRequestMeta, MetaContext, SpanId, TransportService},
     CubeErrorCauseType,
@@ -78,6 +82,11 @@ pub trait QueryEngine {
 
     fn sanitize_statement(&self, stmt: &Self::AstStatementType) -> Self::AstStatementType;
 
+    async fn get_cache_entry_and_refresh_cache_if_needed(
+        &self,
+        state: Arc<SessionState>,
+    ) -> Result<Arc<CompilerCacheEntry>, CompilationError>;
+
     async fn plan(
         &self,
         stmt: Self::AstStatementType,
@@ -86,6 +95,28 @@
         meta: Arc<MetaContext>,
         state: Arc<SessionState>,
     ) -> CompilationResult<(QueryPlan, Self::PlanMetadataType)> {
+        let cache_entry = self
+            .get_cache_entry_and_refresh_cache_if_needed(state.clone())
+            .await?;
+
+        let planning_start = SystemTime::now();
+        if let Some(span_id) = span_id.as_ref() {
+            if let Some(auth_context) = state.auth_context() {
+                self.transport_ref()
+                    .log_load_state(
+                        Some(span_id.clone()),
+                        auth_context,
+                        state.get_load_request_meta(),
+                        "SQL API Query Planning".to_string(),
+                        serde_json::json!({
+                            "query": span_id.query_key.clone(),
+                        }),
+                    )
+                    .await
+                    .map_err(|e| CompilationError::internal(e.to_string()))?;
+            }
+        }
+
         let ctx = self.create_session_ctx(state.clone())?;
         let cube_ctx = self.create_cube_ctx(state.clone(), meta.clone(), ctx.clone())?;
 
@@ -144,7 +175,7 @@
         let mut finalized_graph = self
             .compiler_cache_ref()
             .rewrite(
-                state.auth_context().unwrap(),
+                Arc::clone(&cache_entry),
                 cube_ctx.clone(),
                 converter.take_egraph(),
                 &query_params.unwrap(),
@@ -192,6 +223,7 @@
         let result = rewriter
             .find_best_plan(
                 root,
+                cache_entry,
                 state.auth_context().unwrap(),
                 qtrace,
                 span_id.clone(),
@@ -243,12 +275,31 @@
                 // TODO: We should find what optimizers will be safety to use for OLAP queries
                 guard.optimizer.rules = vec![];
             }
-            if let Some(span_id) = span_id {
+            if let Some(span_id) = &span_id {
                 span_id.set_is_data_query(true).await;
             }
         };
 
        log::debug!("Rewrite: {:#?}", rewrite_plan);
+
+        if let Some(span_id) = span_id.as_ref() {
+            if let Some(auth_context) = state.auth_context() {
+                self.transport_ref()
+                    .log_load_state(
+                        Some(span_id.clone()),
+                        auth_context,
+                        state.get_load_request_meta(),
+                        "SQL API Query Planning Success".to_string(),
+                        serde_json::json!({
+                            "query": span_id.query_key.clone(),
+                            "duration": planning_start.elapsed().unwrap().as_millis() as u64,
+                        }),
+                    )
+                    .await
+                    .map_err(|e| CompilationError::internal(e.to_string()))?;
+            }
+        }
+
         let rewrite_plan = Self::evaluate_wrapped_sql(
             self.transport_ref().clone(),
             Arc::new(state.get_load_request_meta()),
@@ -501,6 +552,21 @@ impl QueryEngine for SqlQueryEngine {
     fn sanitize_statement(&self, stmt: &Self::AstStatementType) -> Self::AstStatementType {
         SensitiveDataSanitizer::new().replace(stmt.clone())
     }
+
+    async fn get_cache_entry_and_refresh_cache_if_needed(
+        &self,
+        state: Arc<SessionState>,
+    ) -> Result<Arc<CompilerCacheEntry>, CompilationError> {
+        self.compiler_cache_ref()
+            .get_cache_entry_and_refresh_if_needed(
+                state.auth_context().ok_or_else(|| {
+                    CompilationError::internal("Unable to get auth context".to_string())
+                })?,
+                state.protocol.clone(),
+            )
+            .await
+            .map_err(|e| CompilationError::internal(e.to_string()))
+    }
 }
 
 fn is_olap_query(parent: &LogicalPlan) -> Result<bool, CompilationError> {
diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs
index 70346db658e19..bb7840374fae4 100644
--- a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs
+++ b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs
@@ -15,7 +15,7 @@ use crate::{
         CubeContext,
     },
     config::ConfigObj,
-    sql::AuthContextRef,
+    sql::{compiler_cache::CompilerCacheEntry, AuthContextRef},
     transport::{MetaContext, SpanId},
     CubeError,
 };
@@ -310,7 +310,7 @@ impl Rewriter {
     pub async fn run_rewrite_to_completion(
         &mut self,
-        auth_context: AuthContextRef,
+        cache_entry: Arc<CompilerCacheEntry>,
         qtrace: &mut Option<Qtrace>,
     ) -> Result<CubeEGraph, CubeError> {
         let cube_context = self.cube_context.clone();
@@ -323,11 +323,7 @@
             .sessions
             .server
             .compiler_cache
-            .rewrite_rules(
-                auth_context.clone(),
-                cube_context.session_state.protocol.clone(),
-                false,
-            )
+            .rewrite_rules(cache_entry, false)
             .await?;
 
         let (plan, qtrace_egraph_iterations) = tokio::task::spawn_blocking(move || {
@@ -392,6 +388,7 @@
     pub async fn find_best_plan(
         &mut self,
         root: Id,
+        cache_entry: Arc<CompilerCacheEntry>,
         auth_context: AuthContextRef,
         qtrace: &mut Option<Qtrace>,
         span_id: Option<Arc<SpanId>>,
@@ -407,11 +404,7 @@
             .sessions
             .server
             .compiler_cache
-            .rewrite_rules(
-                auth_context.clone(),
-                cube_context.session_state.protocol.clone(),
-                true,
-            )
+            .rewrite_rules(cache_entry, true)
             .await?;
 
         let (plan, qtrace_egraph_iterations, qtrace_best_graph) =
diff --git a/rust/cubesql/cubesql/src/compile/router.rs b/rust/cubesql/cubesql/src/compile/router.rs
index 41e9e5c5213d9..6d5bf3a8f1f77 100644
--- a/rust/cubesql/cubesql/src/compile/router.rs
+++ b/rust/cubesql/cubesql/src/compile/router.rs
@@ -3,7 +3,7 @@ use crate::compile::{
     StatusFlags,
 };
 use sqlparser::ast;
-use std::{collections::HashMap, sync::Arc, time::SystemTime};
+use std::{collections::HashMap, sync::Arc};
 
 use crate::{
     compile::{
@@ -61,50 +61,8 @@ impl QueryRouter {
         qtrace: &mut Option<Qtrace>,
         span_id: Option<Arc<SpanId>>,
     ) -> CompilationResult<QueryPlan> {
-        let planning_start = SystemTime::now();
-        if let Some(span_id) = span_id.as_ref() {
-            if let Some(auth_context) = self.state.auth_context() {
-                self.session_manager
-                    .server
-                    .transport
-                    .log_load_state(
-                        Some(span_id.clone()),
-                        auth_context,
-                        self.state.get_load_request_meta(),
-                        "SQL API Query Planning".to_string(),
-                        serde_json::json!({
-                            "query": span_id.query_key.clone(),
-                        }),
-                    )
-                    .await
-                    .map_err(|e| CompilationError::internal(e.to_string()))?;
-            }
-        }
-        let result = self
-            .create_df_logical_plan(stmt.clone(), qtrace, span_id.clone())
-            .await?;
-
-        if let Some(span_id) = span_id.as_ref() {
-            if let Some(auth_context) = self.state.auth_context() {
-                self.session_manager
-                    .server
-                    .transport
-                    .log_load_state(
-                        Some(span_id.clone()),
-                        auth_context,
-                        self.state.get_load_request_meta(),
-                        "SQL API Query Planning Success".to_string(),
-                        serde_json::json!({
-                            "query": span_id.query_key.clone(),
-                            "duration": planning_start.elapsed().unwrap().as_millis() as u64,
-                        }),
-                    )
-                    .await
-                    .map_err(|e| CompilationError::internal(e.to_string()))?;
-            }
-        }
-
-        return Ok(result);
+        self.create_df_logical_plan(stmt.clone(), qtrace, span_id.clone())
+            .await
     }
 
     pub async fn plan(
diff --git a/rust/cubesql/cubesql/src/sql/compiler_cache.rs b/rust/cubesql/cubesql/src/sql/compiler_cache.rs
index 80131a5192e9a..139d85e2b165c 100644
--- a/rust/cubesql/cubesql/src/sql/compiler_cache.rs
+++ b/rust/cubesql/cubesql/src/sql/compiler_cache.rs
@@ -21,20 +21,18 @@ pub trait CompilerCache: Send + Sync + Debug {
     async fn rewrite_rules(
         &self,
-        ctx: AuthContextRef,
-        protocol: DatabaseProtocol,
+        cache_entry: Arc<CompilerCacheEntry>,
         eval_stable_functions: bool,
     ) -> Result<Arc<Vec<CubeRewrite>>, CubeError>;
 
     async fn meta(
         &self,
-        ctx: AuthContextRef,
-        protocol: DatabaseProtocol,
+        cache_entry: Arc<CompilerCacheEntry>,
     ) -> Result<Arc<MetaContext>, CubeError>;
 
     async fn parameterized_rewrite(
         &self,
-        ctx: AuthContextRef,
+        cache_entry: Arc<CompilerCacheEntry>,
         cube_context: Arc<CubeContext>,
         input_plan: CubeEGraph,
         qtrace: &mut Option<Qtrace>,
     ) -> Result<CubeEGraph, CubeError>;
@@ -42,12 +40,18 @@
     async fn rewrite(
         &self,
-        ctx: AuthContextRef,
+        cache_entry: Arc<CompilerCacheEntry>,
         cube_context: Arc<CubeContext>,
         input_plan: CubeEGraph,
         param_values: &HashMap<usize, ScalarValue>,
         qtrace: &mut Option<Qtrace>,
     ) -> Result<CubeEGraph, CubeError>;
+
+    async fn get_cache_entry_and_refresh_if_needed(
+        &self,
+        ctx: AuthContextRef,
+        protocol: DatabaseProtocol,
+    ) -> Result<Arc<CompilerCacheEntry>, CubeError>;
 }
 
 #[derive(Debug)]
@@ -70,12 +74,9 @@ crate::di_service!(CompilerCacheImpl, [CompilerCache]);
 impl CompilerCache for CompilerCacheImpl {
     async fn rewrite_rules(
         &self,
-        ctx: AuthContextRef,
-        protocol: DatabaseProtocol,
+        cache_entry: Arc<CompilerCacheEntry>,
         eval_stable_functions: bool,
     ) -> Result<Arc<Vec<CubeRewrite>>, CubeError> {
-        let cache_entry = self.get_cache_entry(ctx.clone(), protocol).await?;
-
         let rewrite_rules = {
             cache_entry
                 .rewrite_rules
@@ -105,33 +106,28 @@
     async fn meta(
         &self,
-        ctx: AuthContextRef,
-        protocol: DatabaseProtocol,
+        cache_entry: Arc<CompilerCacheEntry>,
     ) -> Result<Arc<MetaContext>, CubeError> {
-        let cache_entry = self.get_cache_entry(ctx.clone(), protocol).await?;
         Ok(cache_entry.meta_context.clone())
     }
 
     async fn parameterized_rewrite(
         &self,
-        ctx: AuthContextRef,
+        cache_entry: Arc<CompilerCacheEntry>,
         cube_context: Arc<CubeContext>,
         parameterized_graph: CubeEGraph,
         qtrace: &mut Option<Qtrace>,
     ) -> Result<CubeEGraph, CubeError> {
-        let cache_entry = self
-            .get_cache_entry(ctx.clone(), cube_context.session_state.protocol.clone())
-            .await?;
-
         let graph_key = egraph_hash(&parameterized_graph, None);
+        let cache_entry_clone = Arc::clone(&cache_entry);
 
         let mut rewrites_cache_lock = cache_entry.parameterized_cache.lock().await;
         if let Some(rewrite_entry) = rewrites_cache_lock.get(&graph_key) {
             Ok(rewrite_entry.clone())
         } else {
             let mut rewriter = Rewriter::new(parameterized_graph, cube_context);
             let rewrite_entry = rewriter
-                .run_rewrite_to_completion(ctx.clone(), qtrace)
+                .run_rewrite_to_completion(cache_entry_clone, qtrace)
                 .await?;
             rewrites_cache_lock.put(graph_key, rewrite_entry.clone());
             Ok(rewrite_entry)
@@ -140,7 +136,7 @@
     async fn rewrite(
         &self,
-        ctx: AuthContextRef,
+        cache_entry: Arc<CompilerCacheEntry>,
         cube_context: Arc<CubeContext>,
         input_plan: CubeEGraph,
         param_values: &HashMap<usize, ScalarValue>,
         qtrace: &mut Option<Qtrace>,
@@ -149,85 +145,91 @@
         if !self.config_obj.enable_rewrite_cache() {
             let mut rewriter = Rewriter::new(input_plan, cube_context);
             rewriter.add_param_values(param_values)?;
-            return Ok(rewriter.run_rewrite_to_completion(ctx, qtrace).await?);
+            return Ok(rewriter
+                .run_rewrite_to_completion(cache_entry, qtrace)
+                .await?);
         }
 
-        let cache_entry = self
-            .get_cache_entry(ctx.clone(), cube_context.session_state.protocol.clone())
-            .await?;
-
         let graph_key = egraph_hash(&input_plan, Some(param_values));
+        let cache_entry_clone = Arc::clone(&cache_entry);
 
         let mut rewrites_cache_lock = cache_entry.queries_cache.lock().await;
         if let Some(plan) = rewrites_cache_lock.get(&graph_key) {
             Ok(plan.clone())
         } else {
             let graph = if self.config_obj.enable_parameterized_rewrite_cache() {
-                self.parameterized_rewrite(ctx.clone(), cube_context.clone(), input_plan, qtrace)
-                    .await?
+                self.parameterized_rewrite(
+                    Arc::clone(&cache_entry),
+                    cube_context.clone(),
+                    input_plan,
+                    qtrace,
+                )
+                .await?
             } else {
                 input_plan
             };
             let mut rewriter = Rewriter::new(graph, cube_context);
             rewriter.add_param_values(param_values)?;
-            let final_plan = rewriter.run_rewrite_to_completion(ctx, qtrace).await?;
+            let final_plan = rewriter
+                .run_rewrite_to_completion(cache_entry_clone, qtrace)
+                .await?;
             rewrites_cache_lock.put(graph_key, final_plan.clone());
             Ok(final_plan)
         }
     }
-}
-
-impl CompilerCacheImpl {
-    pub fn new(config_obj: Arc<dyn ConfigObj>, transport: Arc<dyn TransportService>) -> Self {
-        let compiler_cache_size = config_obj.compiler_cache_size();
-        CompilerCacheImpl {
-            config_obj,
-            transport,
-            compiler_id_to_entry: MutexAsync::new(LruCache::new(
-                NonZeroUsize::new(compiler_cache_size).unwrap(),
-            )),
-        }
-    }
 
-    pub async fn get_cache_entry(
+    async fn get_cache_entry_and_refresh_if_needed(
         &self,
         ctx: AuthContextRef,
         protocol: DatabaseProtocol,
     ) -> Result<Arc<CompilerCacheEntry>, CubeError> {
         let compiler_id = self.transport.compiler_id(ctx.clone()).await?;
-        let cache_entry = {
-            self.compiler_id_to_entry
+        {
+            if let Some(cache_entry) = self
+                .compiler_id_to_entry
                 .lock()
                 .await
                 .get(&(compiler_id, protocol.clone()))
-                .cloned()
-        };
-        // Double checked locking
-        let cache_entry = if let Some(cache_entry) = cache_entry {
-            cache_entry
-        } else {
-            let meta_context = self.transport.meta(ctx.clone()).await?;
+            {
+                return Ok(Arc::clone(cache_entry));
+            }
+        }
+
+        let meta_context = self.transport.meta(ctx).await?;
+        let cache_entry = {
             let mut compiler_id_to_entry = self.compiler_id_to_entry.lock().await;
-            compiler_id_to_entry
-                .get(&(meta_context.compiler_id, protocol.clone()))
-                .cloned()
-                .unwrap_or_else(|| {
-                    let cache_entry = Arc::new(CompilerCacheEntry {
-                        meta_context: meta_context.clone(),
-                        rewrite_rules: RWLockAsync::new(HashMap::new()),
-                        parameterized_cache: MutexAsync::new(LruCache::new(
-                            NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(),
-                        )),
-                        queries_cache: MutexAsync::new(LruCache::new(
-                            NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(),
-                        )),
-                    });
-                    compiler_id_to_entry.put(
-                        (meta_context.compiler_id.clone(), protocol.clone()),
-                        cache_entry.clone(),
-                    );
-                    cache_entry
-                })
+            let cache_entry = Arc::new(CompilerCacheEntry {
+                meta_context: meta_context.clone(),
+                rewrite_rules: RWLockAsync::new(HashMap::new()),
+                parameterized_cache: MutexAsync::new(LruCache::new(
+                    NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(),
+                )),
+                queries_cache: MutexAsync::new(LruCache::new(
+                    NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(),
+                )),
+            });
+            if !compiler_id_to_entry.contains(&(meta_context.compiler_id, protocol.clone())) {
+                compiler_id_to_entry.put(
+                    (meta_context.compiler_id.clone(), protocol.clone()),
+                    cache_entry.clone(),
+                );
+            }
+            cache_entry
         };
+
         Ok(cache_entry)
     }
 }
+
+impl CompilerCacheImpl {
+    pub fn new(config_obj: Arc<dyn ConfigObj>, transport: Arc<dyn TransportService>) -> Self {
+        let compiler_cache_size = config_obj.compiler_cache_size();
+        CompilerCacheImpl {
+            config_obj,
+            transport,
+            compiler_id_to_entry: MutexAsync::new(LruCache::new(
+                NonZeroUsize::new(compiler_cache_size).unwrap(),
+            )),
+        }
+    }
+}
diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs
index 55275b88cbfbe..425e87af1dd91 100644
--- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs
+++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs
@@ -12,6 +12,7 @@ use crate::{
         CommandCompletion, CompilationError, DatabaseProtocol, QueryPlan, StatusFlags,
     },
     sql::{
+        compiler_cache::CompilerCacheEntry,
         df_type_to_pg_tid,
         extended::{Cursor, Portal, PortalBatch, PortalFrom},
         statement::{PostgresStatementParamsFinder, StatementPlaceholderReplacer},
@@ -240,6 +241,20 @@ impl AsyncPostgresShim {
         return Ok(());
     }
 
+    async fn get_cache_entry_and_refresh_cache_if_needed(
+        &self,
+    ) -> Result<Arc<CompilerCacheEntry>, CubeError> {
+        self.session
+            .session_manager
+            .server
+            .compiler_cache
+            .get_cache_entry_and_refresh_if_needed(
+                self.auth_context()?,
+                self.session.state.protocol.clone(),
+            )
+            .await
+    }
+
     pub async fn run_on(
         fast_shutdown_interruptor: CancellationToken,
         semifast_shutdown_interruptor: CancellationToken,
@@ -1058,12 +1073,8 @@
         source_statement.bind(body.to_bind_values(&parameters)?)?;
         drop(statements_guard);
 
-        let meta = self
-            .session
-            .server
-            .compiler_cache
-            .meta(self.auth_context()?, self.session.state.protocol.clone())
-            .await?;
+        let cache_entry = self.get_cache_entry_and_refresh_cache_if_needed().await?;
+        let meta = self.session.server.compiler_cache.meta(cache_entry).await?;
 
         let plan = convert_statement_to_cube_query(
             prepared_statement,
@@ -1171,12 +1182,8 @@
             .map(|param| param.coltype.to_pg_tid())
             .collect();
 
-        let meta = self
-            .session
-            .server
-            .compiler_cache
-            .meta(self.auth_context()?, self.session.state.protocol.clone())
-            .await?;
+        let cache_entry = self.get_cache_entry_and_refresh_cache_if_needed().await?;
+        let meta = self.session.server.compiler_cache.meta(cache_entry).await?;
 
         let stmt_replacer = StatementPlaceholderReplacer::new();
         let hacked_query = stmt_replacer.replace(query.clone())?;
@@ -1794,12 +1801,8 @@
         qtrace: &mut Option<Qtrace>,
         span_id: Option<Arc<SpanId>>,
     ) -> Result<(), ConnectionError> {
-        let meta = self
-            .session
-            .server
-            .compiler_cache
-            .meta(self.auth_context()?, self.session.state.protocol.clone())
-            .await?;
+        let cache_entry = self.get_cache_entry_and_refresh_cache_if_needed().await?;
+        let meta = self.session.server.compiler_cache.meta(cache_entry).await?;
 
         let statements =
             parse_sql_to_statements(&query.to_string(), DatabaseProtocol::PostgreSQL, qtrace)?;
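Reviewer note: get_cache_entry_and_refresh_if_needed above deliberately releases
the cache lock across the async transport calls (compiler_id, meta), so two
sessions racing on a cold cache may each build an entry; only the first one is
published, and the loser simply uses its own copy. Below is a minimal standalone
sketch of that shape. Cache, CacheEntry, Key, and fetch_meta are illustrative
stand-ins, not cubesql's types; only the lru and tokio crates are real
dependencies.

    // Sketch of the "resolve once, publish unless already present" pattern.
    // Not cubesql code; all types here are hypothetical stand-ins.
    use std::{num::NonZeroUsize, sync::Arc};

    use lru::LruCache;
    use tokio::sync::Mutex;

    // Stands in for the (compiler id, protocol) pair keying the real cache.
    type Key = (u64, String);

    struct CacheEntry {
        meta: String, // stands in for Arc<MetaContext> plus the per-entry caches
    }

    struct Cache {
        id_to_entry: Mutex<LruCache<Key, Arc<CacheEntry>>>,
    }

    impl Cache {
        fn new(capacity: usize) -> Self {
            Cache {
                id_to_entry: Mutex::new(LruCache::new(NonZeroUsize::new(capacity).unwrap())),
            }
        }

        // Stands in for the async transport call (transport.meta in the patch).
        async fn fetch_meta(&self, id: u64) -> String {
            format!("meta for compiler {id}")
        }

        async fn get_entry_or_refresh(&self, id: u64, protocol: &str) -> Arc<CacheEntry> {
            let key: Key = (id, protocol.to_string());

            // Fast path: hold the lock only long enough to clone the Arc.
            {
                if let Some(entry) = self.id_to_entry.lock().await.get(&key) {
                    return Arc::clone(entry);
                }
            }

            // Slow path: do the expensive async fetch *without* holding the
            // lock, then publish the fresh entry unless a concurrent task
            // already did. Either way the caller gets a usable entry.
            let meta = self.fetch_meta(id).await;
            let entry = Arc::new(CacheEntry { meta });

            let mut guard = self.id_to_entry.lock().await;
            if !guard.contains(&key) {
                guard.put(key, Arc::clone(&entry));
            }
            entry
        }
    }

    #[tokio::main]
    async fn main() {
        let cache = Cache::new(16);
        let entry = cache.get_entry_or_refresh(1, "postgres").await;
        println!("{}", entry.meta);
    }

One observable difference from the removed double-checked-locking variant: the
second lock acquisition no longer re-reads the map to return the winner's entry;
it only refrains from overwriting it, and the racing caller keeps the entry it
just built.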