From 8d6b0d0beb798a07af993f189c5135684bba827f Mon Sep 17 00:00:00 2001 From: Mikhail Cheshkov Date: Wed, 11 Sep 2024 16:25:18 +0300 Subject: [PATCH] refactor(cubesql): Extract CubeScanWrappedSqlNode from CubeScanWrapperNode --- .../cubesql/src/compile/engine/df/scan.rs | 31 +-- .../cubesql/src/compile/engine/df/wrapper.rs | 100 +++++-- rust/cubesql/cubesql/src/compile/mod.rs | 248 ++++-------------- .../cubesql/src/compile/query_engine.rs | 2 + .../cubesql/src/compile/test/test_wrapper.rs | 188 ++++--------- .../cubesql/cubesql/src/compile/test/utils.rs | 17 +- 6 files changed, 201 insertions(+), 385 deletions(-) diff --git a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs index 11d7f03695d5f..92095825369e5 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/scan.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/scan.rs @@ -30,7 +30,7 @@ use std::{ use crate::{ compile::{ - engine::df::wrapper::{CubeScanWrapperNode, SqlQuery}, + engine::df::wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode, SqlQuery}, rewrite::WrappedSelectType, test::find_cube_scans_deep_search, }, @@ -394,35 +394,32 @@ impl ExtensionPlanner for CubeScanExtensionPlanner { config_obj: self.config_obj.clone(), })) } else if let Some(wrapper_node) = node.as_any().downcast_ref::() { + return Err(DataFusionError::Internal(format!( + "CubeScanWrapperNode is not executable, SQL should be generated first with QueryEngine::evaluate_wrapped_sql: {:?}", + wrapper_node + ))); + } else if let Some(wrapped_sql_node) = + node.as_any().downcast_ref::() + { // TODO // assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs"); // assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs"); let scan_node = - find_cube_scans_deep_search(wrapper_node.wrapped_plan.clone(), false) + find_cube_scans_deep_search(wrapped_sql_node.wrapped_plan.clone(), false) .into_iter() .next() .ok_or(DataFusionError::Internal(format!( "No cube scans found in wrapper node: {:?}", - wrapper_node + wrapped_sql_node )))?; - let schema = SchemaRef::new(wrapper_node.schema().as_ref().into()); + let schema = SchemaRef::new(wrapped_sql_node.schema().as_ref().into()); Some(Arc::new(CubeScanExecutionPlan { schema, - member_fields: wrapper_node.member_fields.as_ref().ok_or_else(|| { - DataFusionError::Internal(format!( - "Member fields are not set for wrapper node. Optimization wasn't performed: {:?}", - wrapper_node - )) - })?.clone(), + member_fields: wrapped_sql_node.member_fields.clone(), transport: self.transport.clone(), - request: wrapper_node.request.clone().unwrap_or(scan_node.request.clone()), - wrapped_sql: Some(wrapper_node.wrapped_sql.as_ref().ok_or_else(|| { - DataFusionError::Internal(format!( - "Wrapped SQL is not set for wrapper node. Optimization wasn't performed: {:?}", - wrapper_node - )) - })?.clone()), + request: wrapped_sql_node.request.clone(), + wrapped_sql: Some(wrapped_sql_node.wrapped_sql.clone()), auth_context: scan_node.auth_context.clone(), options: scan_node.options.clone(), meta: self.meta.clone(), diff --git a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs index 3cf95d0ba8a29..2ebdf9b409bf4 100644 --- a/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs @@ -201,14 +201,75 @@ impl SqlQuery { } } +#[derive(Clone, Debug)] +pub struct CubeScanWrappedSqlNode { + // TODO maybe replace wrapped plan with schema + scan_node + pub wrapped_plan: Arc, + pub wrapped_sql: SqlQuery, + pub request: TransportLoadRequestQuery, + pub member_fields: Vec, +} + +impl CubeScanWrappedSqlNode { + pub fn new( + wrapped_plan: Arc, + sql: SqlQuery, + request: TransportLoadRequestQuery, + member_fields: Vec, + ) -> Self { + Self { + wrapped_plan, + wrapped_sql: sql, + request, + member_fields, + } + } +} + +impl UserDefinedLogicalNode for CubeScanWrappedSqlNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![] + } + + fn schema(&self) -> &DFSchemaRef { + self.wrapped_plan.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { + // TODO figure out nice plan for wrapped plan + write!(f, "CubeScanWrappedSql") + } + + fn from_template( + &self, + exprs: &[datafusion::logical_plan::Expr], + inputs: &[datafusion::logical_plan::LogicalPlan], + ) -> std::sync::Arc { + assert_eq!(inputs.len(), 0, "input size inconsistent"); + assert_eq!(exprs.len(), 0, "expression size inconsistent"); + + Arc::new(CubeScanWrappedSqlNode { + wrapped_plan: self.wrapped_plan.clone(), + wrapped_sql: self.wrapped_sql.clone(), + request: self.request.clone(), + member_fields: self.member_fields.clone(), + }) + } +} + #[derive(Debug, Clone)] pub struct CubeScanWrapperNode { pub wrapped_plan: Arc, pub meta: Arc, pub auth_context: AuthContextRef, - pub wrapped_sql: Option, - pub request: Option, - pub member_fields: Option>, pub span_id: Option>, pub config_obj: Arc, } @@ -225,31 +286,10 @@ impl CubeScanWrapperNode { wrapped_plan, meta, auth_context, - wrapped_sql: None, - request: None, - member_fields: None, span_id, config_obj, } } - - pub fn with_sql_and_request( - &self, - sql: SqlQuery, - request: TransportLoadRequestQuery, - member_fields: Vec, - ) -> Self { - Self { - wrapped_plan: self.wrapped_plan.clone(), - meta: self.meta.clone(), - auth_context: self.auth_context.clone(), - wrapped_sql: Some(sql), - request: Some(request), - member_fields: Some(member_fields), - span_id: self.span_id.clone(), - config_obj: self.config_obj.clone(), - } - } } fn expr_name(e: &Expr, schema: &Arc) -> Result { @@ -317,7 +357,7 @@ impl CubeScanWrapperNode { &self, transport: Arc, load_request_meta: Arc, - ) -> result::Result { + ) -> result::Result { let schema = self.schema(); let wrapped_plan = self.wrapped_plan.clone(); let (sql, request, member_fields) = Self::generate_sql_for_node( @@ -361,7 +401,12 @@ impl CubeScanWrapperNode { sql.finalize_query(sql_templates).map_err(|e| CubeError::internal(e.to_string()))?; Ok((sql, request, member_fields)) })?; - Ok(self.with_sql_and_request(sql, request, member_fields)) + Ok(CubeScanWrappedSqlNode::new( + self.wrapped_plan.clone(), + sql, + request, + member_fields, + )) } pub fn set_max_limit_for_node(self, node: Arc) -> Arc { @@ -2242,9 +2287,6 @@ impl UserDefinedLogicalNode for CubeScanWrapperNode { wrapped_plan: self.wrapped_plan.clone(), meta: self.meta.clone(), auth_context: self.auth_context.clone(), - wrapped_sql: self.wrapped_sql.clone(), - request: self.request.clone(), - member_fields: self.member_fields.clone(), span_id: self.span_id.clone(), config_obj: self.config_obj.clone(), }) diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index e0b4974a99166..96784e627e207 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -336,11 +336,7 @@ mod tests { ) .await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); assert!(sql.contains(" IN (")); @@ -351,11 +347,7 @@ mod tests { ) .await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); assert!(sql.contains(" IN (")); @@ -374,11 +366,7 @@ mod tests { DatabaseProtocol::PostgreSQL, ).await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LOWER(")); } @@ -2777,17 +2765,15 @@ limit assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("sixteen_charchar_1")); assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("sixteen_charchar_2")); } @@ -6948,11 +6934,7 @@ ORDER BY DatabaseProtocol::PostgreSQL ).await.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!( sql.contains(expected_search_expr), "cast_expr is {}, expected_search_expr is {}", @@ -7256,9 +7238,8 @@ ORDER BY assert_eq!( query_plan .as_logical_plan() - .find_cube_scan_wrapper() - .request - .unwrap(), + .find_cube_scan_wrapped_sql() + .request, V1LoadRequestQuery { measures: Some(vec![ json!({ @@ -7371,9 +7352,8 @@ ORDER BY "source"."str0" ASC ); assert!(!query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ungrouped")); } @@ -10989,11 +10969,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `LOWER(..) <> .. OR .. IS NULL` let re = Regex::new(r"LOWER ?\(.+\) != .+ OR .+ IS NULL").unwrap(); @@ -11026,11 +11002,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `NOT(LOWER(..) IN (..))` let re = Regex::new(r"NOT.+LOWER ?\(.+\).* IN ").unwrap(); @@ -11326,11 +11298,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `NOT(LOWER(..) IN (..)) OR NOT(.. IS NOT NULL)` let re = Regex::new(r"NOT.+LOWER ?\(.+\) IN .+\) OR NOT.+ IS NOT NULL").unwrap(); @@ -11650,9 +11618,8 @@ ORDER BY "source"."str0" ASC assert_eq!( logical_plan - .find_cube_scan_wrapper() - .request - .unwrap(), + .find_cube_scan_wrapped_sql() + .request, V1LoadRequestQuery { measures: Some(vec![]), dimensions: Some(vec![ @@ -11936,9 +11903,8 @@ ORDER BY "source"."str0" ASC assert_eq!( query_plan .as_logical_plan() - .find_cube_scan_wrapper() - .request - .unwrap(), + .find_cube_scan_wrapped_sql() + .request, V1LoadRequestQuery { measures: Some(vec![]), dimensions: Some(vec![ @@ -12243,11 +12209,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check wrapping for `NOT(.. IS NULL OR LOWER(..) IN)` let re = Regex::new(r"NOT \(.+ IS NULL OR .*LOWER\(.+ IN ").unwrap(); @@ -12294,11 +12256,7 @@ ORDER BY "source"."str0" ASC let re = Regex::new(r"\(LOWER ?\(.+\) = .+ OR .+LOWER ?\(.+\) = .+\) IN \(TRUE, FALSE\)") .unwrap(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(re.is_match(&sql)); } @@ -12761,7 +12719,7 @@ ORDER BY "source"."str0" ASC let end_date = chrono::Utc::now().date_naive(); let start_date = end_date - chrono::Duration::days(30); assert_eq!( - logical_plan.find_cube_scan_wrapper().request.unwrap(), + logical_plan.find_cube_scan_wrapped_sql().request, V1LoadRequestQuery { measures: Some(vec![ json!({ @@ -12877,9 +12835,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("LEFT")); } @@ -13228,9 +13185,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13308,9 +13264,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13337,9 +13292,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13366,9 +13320,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT")); } @@ -13390,17 +13343,12 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("OVER"), "SQL should contain 'OVER': {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -13427,32 +13375,22 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("long_l_1"), "SQL should contain long_l_1: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("long_l_1"), "SQL should contain long_l_2: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -13478,9 +13416,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CURRENT_DATE()")); @@ -13534,11 +13471,7 @@ ORDER BY "source"."str0" ASC .await .as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; // check if contains `CAST(EXTRACT(YEAR FROM ..) || .. || .. || ..)` let re = Regex::new(r"CAST.+EXTRACT.+YEAR FROM(.+ \|\|){3}").unwrap(); @@ -13691,11 +13624,7 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); if Rewriter::sql_push_down_enabled() { - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("EXTRACT(YEAR")); assert!(sql.contains("EXTRACT(MONTH")); @@ -13799,11 +13728,7 @@ ORDER BY "source"."str0" ASC // TODO: split on complex expressions? // CAST(CAST(ta_1.order_date AS Date32) - CAST(CAST(Utf8("1970-01-01") AS Date32) AS Date32) + Int64(3) AS Decimal(38, 10)) if Rewriter::sql_push_down_enabled() { - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; if Rewriter::top_down_extractor_enabled() { assert!(sql.contains("LIMIT 1000")); } else { @@ -14175,11 +14100,7 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); if Rewriter::sql_push_down_enabled() { - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LIMIT 101")); assert!(sql.contains("ORDER BY")); @@ -14289,9 +14210,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("NOT IN (")); } @@ -14356,9 +14276,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("NOT (")); } @@ -14390,9 +14309,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("DATEDIFF(day,")); @@ -14419,11 +14337,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATETIME_DIFF(CAST(")); assert!(sql.contains("day)")); @@ -14450,11 +14364,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATEDIFF(day,")); assert!(sql.contains("DATE_TRUNC('day',")); @@ -14481,11 +14391,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("CASE WHEN LOWER('day')")); assert!(sql.contains("WHEN 'year' THEN 12 WHEN 'quarter' THEN 3 WHEN 'month' THEN 1 END")); assert!(sql.contains("EXTRACT(EPOCH FROM")); @@ -14520,9 +14426,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("DATEADD(day, 7,")); @@ -14549,11 +14454,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("DATETIME_ADD(CAST(")); assert!(sql.contains("INTERVAL 7 day)")); @@ -14581,11 +14482,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("+ '7 day'::interval")); } @@ -14661,9 +14558,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("DATE(")); } @@ -14700,9 +14596,8 @@ ORDER BY "source"."str0" ASC let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("EXTRACT(MONTH FROM ")); } @@ -14743,11 +14638,7 @@ ORDER BY "source"."str0" ASC ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("order_date")); assert!(sql.contains("EXTRACT(DAY FROM")) } @@ -14863,11 +14754,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("OFFSET 1\nLIMIT 2")); } @@ -15114,9 +15001,8 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("SELECT DISTINCT ")); @@ -15194,9 +15080,8 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW")); @@ -15720,9 +15605,8 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("LIMIT 250")); @@ -16003,11 +15887,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" AS STRING)")); assert!(sql.contains(" AS FLOAT)")); assert!(sql.contains(" AS DOUBLE)")); @@ -16035,11 +15915,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" AS STRING)")); assert!(sql.contains(" AS FLOAT64)")); assert!(sql.contains(" AS BIGDECIMAL(38,10))")); @@ -16063,11 +15939,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" AS TEXT)")); assert!(sql.contains(" AS REAL)")); assert!(sql.contains(" AS DOUBLE PRECISION)")); @@ -16185,11 +16057,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("LIKE ")); assert!(sql.contains("ESCAPE ")); @@ -16326,11 +16194,7 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), ); let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains(" IS NULL DESC, ")); } } diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index 2532fa12efc48..629ed754e5732 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -249,6 +249,8 @@ pub trait QueryEngine { }; log::debug!("Rewrite: {:#?}", rewrite_plan); + // We want to generate SQL early, as a part of planning, and not later (like during execution) + // to catch all SQL generation errors during planning let rewrite_plan = Self::evaluate_wrapped_sql( self.transport_ref().clone(), Arc::new(state.get_load_request_meta()), diff --git a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs index 5f5f039d0944d..05c25b3f15702 100644 --- a/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs +++ b/rust/cubesql/cubesql/src/compile/test/test_wrapper.rs @@ -32,9 +32,8 @@ async fn test_simple_wrapper() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("COALESCE")); @@ -56,11 +55,7 @@ async fn test_wrapper_group_by_rollup() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -81,11 +76,7 @@ async fn test_wrapper_group_by_rollup_with_aliases() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -106,11 +97,7 @@ async fn test_wrapper_group_by_rollup_nested() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("ROLLUP(1, 2)")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -131,11 +118,7 @@ async fn test_wrapper_group_by_rollup_nested_from_asterisk() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -156,11 +139,7 @@ async fn test_wrapper_group_by_rollup_nested_with_aliases() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("ROLLUP(1, 2)")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -183,11 +162,7 @@ async fn test_wrapper_group_by_rollup_nested_complex() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("ROLLUP(1), ROLLUP(2), 3, CUBE(4)")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -208,11 +183,7 @@ async fn test_wrapper_group_by_rollup_placeholders() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -233,11 +204,7 @@ async fn test_wrapper_group_by_cube() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Cube")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -258,11 +225,7 @@ async fn test_wrapper_group_by_rollup_complex() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("Rollup")); let _physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -283,11 +246,7 @@ async fn test_simple_subquery_wrapper_projection_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -310,11 +269,7 @@ async fn test_simple_subquery_wrapper_filter_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -337,11 +292,7 @@ async fn test_simple_subquery_wrapper_projection_aggregate_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -363,11 +314,7 @@ async fn test_simple_subquery_wrapper_filter_in_empty_source() { .await; let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("IN (SELECT")); assert!(sql.contains("utf8__male__")); @@ -390,11 +337,7 @@ async fn test_simple_subquery_wrapper_filter_and_projection_empty_source() { let logical_plan = query_plan.as_logical_plan(); - let sql = logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql; + let sql = logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql; assert!(sql.contains("IN (SELECT")); assert!(sql.contains("(SELECT")); assert!(sql.contains("utf8__male__")); @@ -419,15 +362,13 @@ async fn test_simple_subquery_wrapper_projection() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("(SELECT")); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("\\\\\\\"limit\\\\\\\": 1")); @@ -450,9 +391,8 @@ async fn test_simple_subquery_wrapper_projection_aggregate() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("(SELECT")); @@ -475,15 +415,13 @@ async fn test_simple_subquery_wrapper_filter_equal() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("(SELECT")); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("\\\\\\\"limit\\\\\\\": 1")); @@ -506,9 +444,8 @@ async fn test_simple_subquery_wrapper_filter_in() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("IN (SELECT")); @@ -532,9 +469,8 @@ async fn test_simple_subquery_wrapper_filter_and_projection() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("IN (SELECT")); @@ -580,9 +516,8 @@ GROUP BY assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains( "${KibanaSampleDataEcommerce.order_date} >= timestamptz '2024-02-03T04:05:06.000Z'" @@ -626,9 +561,8 @@ WHERE assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains( "${KibanaSampleDataEcommerce.order_date} >= timestamptz '2024-02-03T04:05:06.000Z'" @@ -671,9 +605,8 @@ GROUP BY assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("${KibanaSampleDataEcommerce.order_date} >= timestamptz")); } @@ -712,9 +645,8 @@ WHERE assert!(query_plan .as_logical_plan() - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("${KibanaSampleDataEcommerce.order_date} >= timestamptz")); } @@ -735,9 +667,8 @@ async fn test_case_wrapper() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -773,9 +704,8 @@ async fn test_case_wrapper_distinct() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -802,9 +732,8 @@ async fn test_case_wrapper_alias_with_order() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ORDER BY \"case_when_a_cust\"")); @@ -831,9 +760,8 @@ async fn test_case_wrapper_ungrouped() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -865,9 +793,8 @@ async fn test_case_wrapper_non_strict_match() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -900,9 +827,8 @@ async fn test_case_wrapper_ungrouped_sorted() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("ORDER BY")); } @@ -929,9 +855,8 @@ async fn test_case_wrapper_ungrouped_sorted_aliased() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql // TODO test without depend on column name .contains("ORDER BY \"case_when")); @@ -953,25 +878,19 @@ async fn test_case_wrapper_with_internal_limit() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("1123"), "SQL contains 1123: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -999,19 +918,14 @@ async fn test_case_wrapper_with_system_fields() { assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains( "\\\"cube_name\\\":\\\"KibanaSampleDataEcommerce\\\",\\\"alias\\\":\\\"user\\\"" ), r#"SQL contains `\"cube_name\":\"KibanaSampleDataEcommerce\",\"alias\":\"user\"` {}"#, - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -1037,25 +951,19 @@ async fn test_case_wrapper_with_limit() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); assert!( logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("1123"), "SQL contains 1123: {}", - logical_plan - .find_cube_scan_wrapper() - .wrapped_sql - .unwrap() - .sql + logical_plan.find_cube_scan_wrapped_sql().wrapped_sql.sql ); let physical_plan = query_plan.as_physical_plan().await.unwrap(); @@ -1081,9 +989,8 @@ async fn test_case_wrapper_with_null() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("CASE WHEN")); @@ -1138,9 +1045,8 @@ async fn test_case_wrapper_escaping() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql // Expect 6 backslashes as output is JSON and it's escaped one more time .contains("\\\\\\\\\\\\`")); @@ -1174,9 +1080,8 @@ async fn test_wrapper_limit_zero() { let logical_plan = query_plan.as_logical_plan(); assert!(logical_plan - .find_cube_scan_wrapper() + .find_cube_scan_wrapped_sql() .wrapped_sql - .unwrap() .sql .contains("LIMIT 0")); @@ -1219,9 +1124,8 @@ async fn test_wrapper_filter_flatten() { assert_eq!( query_plan .as_logical_plan() - .find_cube_scan_wrapper() - .request - .unwrap(), + .find_cube_scan_wrapped_sql() + .request, TransportLoadRequestQuery { measures: Some(vec![json!({ "cube_name": "KibanaSampleDataEcommerce", diff --git a/rust/cubesql/cubesql/src/compile/test/utils.rs b/rust/cubesql/cubesql/src/compile/test/utils.rs index 5193918dc97b4..e22772a655b61 100644 --- a/rust/cubesql/cubesql/src/compile/test/utils.rs +++ b/rust/cubesql/cubesql/src/compile/test/utils.rs @@ -3,14 +3,17 @@ use std::sync::Arc; use datafusion::logical_plan::{plan::Extension, Filter, LogicalPlan, PlanVisitor}; use crate::{ - compile::engine::df::{scan::CubeScanNode, wrapper::CubeScanWrapperNode}, + compile::engine::df::{ + scan::CubeScanNode, + wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode}, + }, CubeError, }; pub trait LogicalPlanTestUtils { fn find_cube_scan(&self) -> CubeScanNode; - fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode; + fn find_cube_scan_wrapped_sql(&self) -> CubeScanWrappedSqlNode; fn find_cube_scans(&self) -> Vec; @@ -27,13 +30,13 @@ impl LogicalPlanTestUtils for LogicalPlan { cube_scans[0].clone() } - fn find_cube_scan_wrapper(&self) -> CubeScanWrapperNode { + fn find_cube_scan_wrapped_sql(&self) -> CubeScanWrappedSqlNode { match self { LogicalPlan::Extension(Extension { node }) => { - if let Some(wrapper_node) = node.as_any().downcast_ref::() { + if let Some(wrapper_node) = node.as_any().downcast_ref::() { wrapper_node.clone() } else { - panic!("Root plan node is not cube_scan_wrapper!"); + panic!("Root plan node is not cube_scan_wrapped_sql!"); } } _ => panic!("Root plan node is not extension!"), @@ -66,6 +69,10 @@ pub fn find_cube_scans_deep_search( ext.node.as_any().downcast_ref::() { wrapper_node.wrapped_plan.accept(self)?; + } else if let Some(wrapper_node) = + ext.node.as_any().downcast_ref::() + { + wrapper_node.wrapped_plan.accept(self)?; } } Ok(true)