Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(cubesql): Extract CubeScanWrappedSqlNode from CubeScanWrapperNode #8786

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions rust/cubesql/cubesql/src/compile/engine/df/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

use crate::{
compile::{
engine::df::wrapper::{CubeScanWrapperNode, SqlQuery},
engine::df::wrapper::{CubeScanWrappedSqlNode, CubeScanWrapperNode, SqlQuery},
rewrite::WrappedSelectType,
test::find_cube_scans_deep_search,
},
Expand Down Expand Up @@ -394,35 +394,32 @@
config_obj: self.config_obj.clone(),
}))
} else if let Some(wrapper_node) = node.as_any().downcast_ref::<CubeScanWrapperNode>() {
return Err(DataFusionError::Internal(format!(
"CubeScanWrapperNode is not executable, SQL should be generated first with QueryEngine::evaluate_wrapped_sql: {:?}",
wrapper_node
)));

Check warning on line 400 in rust/cubesql/cubesql/src/compile/engine/df/scan.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/scan.rs#L397-L400

Added lines #L397 - L400 were not covered by tests
} else if let Some(wrapped_sql_node) =
node.as_any().downcast_ref::<CubeScanWrappedSqlNode>()
{
// TODO
// assert_eq!(logical_inputs.len(), 0, "Inconsistent number of inputs");
// assert_eq!(physical_inputs.len(), 0, "Inconsistent number of inputs");
let scan_node =
find_cube_scans_deep_search(wrapper_node.wrapped_plan.clone(), false)
find_cube_scans_deep_search(wrapped_sql_node.wrapped_plan.clone(), false)
.into_iter()
.next()
.ok_or(DataFusionError::Internal(format!(
"No cube scans found in wrapper node: {:?}",
wrapper_node
wrapped_sql_node
)))?;

let schema = SchemaRef::new(wrapper_node.schema().as_ref().into());
let schema = SchemaRef::new(wrapped_sql_node.schema().as_ref().into());
Some(Arc::new(CubeScanExecutionPlan {
schema,
member_fields: wrapper_node.member_fields.as_ref().ok_or_else(|| {
DataFusionError::Internal(format!(
"Member fields are not set for wrapper node. Optimization wasn't performed: {:?}",
wrapper_node
))
})?.clone(),
member_fields: wrapped_sql_node.member_fields.clone(),
transport: self.transport.clone(),
request: wrapper_node.request.clone().unwrap_or(scan_node.request.clone()),
wrapped_sql: Some(wrapper_node.wrapped_sql.as_ref().ok_or_else(|| {
DataFusionError::Internal(format!(
"Wrapped SQL is not set for wrapper node. Optimization wasn't performed: {:?}",
wrapper_node
))
})?.clone()),
request: wrapped_sql_node.request.clone(),
wrapped_sql: Some(wrapped_sql_node.wrapped_sql.clone()),
auth_context: scan_node.auth_context.clone(),
options: scan_node.options.clone(),
meta: self.meta.clone(),
Expand Down
100 changes: 71 additions & 29 deletions rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,75 @@
}
}

/// Logical plan node carrying the result of SQL generation for a wrapped
/// (pushed-down) query. Produced from a `CubeScanWrapperNode` by
/// `generate_sql`; unlike that node, the SQL, request and member fields here
/// are always present (not `Option`s), so downstream planning never has to
/// handle a "not yet generated" state.
#[derive(Clone, Debug)]
pub struct CubeScanWrappedSqlNode {
    // TODO maybe replace wrapped plan with schema + scan_node
    /// Original wrapped plan; still needed to expose the node schema
    /// (`schema()` delegates to it) and to locate the underlying cube scan.
    pub wrapped_plan: Arc<LogicalPlan>,
    /// Finalized SQL generated for the wrapped plan.
    pub wrapped_sql: SqlQuery,
    /// Load request sent alongside the generated SQL.
    pub request: TransportLoadRequestQuery,
    /// Member fields handed to the cube scan execution plan
    /// (maps response members to output fields).
    pub member_fields: Vec<MemberField>,
}

impl CubeScanWrappedSqlNode {
pub fn new(
wrapped_plan: Arc<LogicalPlan>,
sql: SqlQuery,
request: TransportLoadRequestQuery,
member_fields: Vec<MemberField>,
) -> Self {
Self {
wrapped_plan,
wrapped_sql: sql,
request,
member_fields,
}
}
}

/// `UserDefinedLogicalNode` implementation so the wrapped-SQL node can be
/// embedded in a DataFusion logical plan. It behaves as a leaf node: no
/// inputs and no expressions.
impl UserDefinedLogicalNode for CubeScanWrappedSqlNode {
    fn as_any(&self) -> &dyn Any {
        self
    }

    // Leaf node: the wrapped plan was already consumed by SQL generation,
    // so nothing is exposed as a child input.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        vec![]
    }

    // Schema is borrowed from the plan the SQL was generated for.
    fn schema(&self) -> &DFSchemaRef {
        self.wrapped_plan.schema()
    }

    fn expressions(&self) -> Vec<Expr> {
        vec![]
    }

    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // TODO figure out nice plan for wrapped plan
        write!(f, "CubeScanWrappedSql")
    }

Check warning on line 249 in rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/df/wrapper.rs#L246-L249

Added lines #L246 - L249 were not covered by tests

    // With zero inputs and zero expressions, "from template" is just a
    // field-by-field clone of this node.
    fn from_template(
        &self,
        exprs: &[datafusion::logical_plan::Expr],
        inputs: &[datafusion::logical_plan::LogicalPlan],
    ) -> std::sync::Arc<dyn UserDefinedLogicalNode + Send + Sync> {
        assert_eq!(inputs.len(), 0, "input size inconsistent");
        assert_eq!(exprs.len(), 0, "expression size inconsistent");

        Arc::new(CubeScanWrappedSqlNode {
            wrapped_plan: self.wrapped_plan.clone(),
            wrapped_sql: self.wrapped_sql.clone(),
            request: self.request.clone(),
            member_fields: self.member_fields.clone(),
        })
    }
}

#[derive(Debug, Clone)]
pub struct CubeScanWrapperNode {
pub wrapped_plan: Arc<LogicalPlan>,
pub meta: Arc<MetaContext>,
pub auth_context: AuthContextRef,
pub wrapped_sql: Option<SqlQuery>,
pub request: Option<TransportLoadRequestQuery>,
pub member_fields: Option<Vec<MemberField>>,
pub span_id: Option<Arc<SpanId>>,
pub config_obj: Arc<dyn ConfigObj>,
}
Expand All @@ -225,31 +286,10 @@
wrapped_plan,
meta,
auth_context,
wrapped_sql: None,
request: None,
member_fields: None,
span_id,
config_obj,
}
}

pub fn with_sql_and_request(
&self,
sql: SqlQuery,
request: TransportLoadRequestQuery,
member_fields: Vec<MemberField>,
) -> Self {
Self {
wrapped_plan: self.wrapped_plan.clone(),
meta: self.meta.clone(),
auth_context: self.auth_context.clone(),
wrapped_sql: Some(sql),
request: Some(request),
member_fields: Some(member_fields),
span_id: self.span_id.clone(),
config_obj: self.config_obj.clone(),
}
}
}

fn expr_name(e: &Expr, schema: &Arc<DFSchema>) -> Result<String> {
Expand Down Expand Up @@ -317,7 +357,7 @@
&self,
transport: Arc<dyn TransportService>,
load_request_meta: Arc<LoadRequestMeta>,
) -> result::Result<Self, CubeError> {
) -> result::Result<CubeScanWrappedSqlNode, CubeError> {
let schema = self.schema();
let wrapped_plan = self.wrapped_plan.clone();
let (sql, request, member_fields) = Self::generate_sql_for_node(
Expand Down Expand Up @@ -361,7 +401,12 @@
sql.finalize_query(sql_templates).map_err(|e| CubeError::internal(e.to_string()))?;
Ok((sql, request, member_fields))
})?;
Ok(self.with_sql_and_request(sql, request, member_fields))
Ok(CubeScanWrappedSqlNode::new(
self.wrapped_plan.clone(),
sql,
request,
member_fields,
))
}

pub fn set_max_limit_for_node(self, node: Arc<LogicalPlan>) -> Arc<LogicalPlan> {
Expand Down Expand Up @@ -2231,9 +2276,6 @@
wrapped_plan: self.wrapped_plan.clone(),
meta: self.meta.clone(),
auth_context: self.auth_context.clone(),
wrapped_sql: self.wrapped_sql.clone(),
request: self.request.clone(),
member_fields: self.member_fields.clone(),
span_id: self.span_id.clone(),
config_obj: self.config_obj.clone(),
})
Expand Down
Loading
Loading