Skip to content

Commit

Permalink
implement the scope analysis and refactor model analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Sep 11, 2024
1 parent 32fcbe2 commit ed135a2
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 62 deletions.
76 changes: 31 additions & 45 deletions wren-modeling-rs/core/src/logical_plan/analyze/model_anlayze.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::logical_plan::analyze::plan::ModelPlanNode;
use crate::logical_plan::utils::belong_to_mdl;
use crate::logical_plan::utils::{belong_to_mdl, expr_to_columns};
use crate::mdl::utils::quoted;
use crate::mdl::{AnalyzedWrenMDL, Dataset, SessionStateRef};
use datafusion::catalog_common::TableReference;
Expand All @@ -8,7 +8,6 @@ use datafusion::common::{internal_err, plan_err, Column, DFSchemaRef, Result};
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::expr::Alias;
use datafusion::logical_expr::utils::expr_to_columns;
use datafusion::logical_expr::{
col, ident, Aggregate, Distinct, DistinctOn, Expr, Extension, Filter, Join,
LogicalPlan, LogicalPlanBuilder, Projection, Subquery, SubqueryAlias, TableScan,
Expand Down Expand Up @@ -49,17 +48,24 @@ impl ModelAnalyzeRule {
scope_buffer: &RefCell<VecDeque<RefCell<Scope>>>,
) -> Result<Transformed<LogicalPlan>> {
plan.transform_up(&|plan| -> Result<Transformed<LogicalPlan>> {
let plan = self.analyze_scope_internal(plan, &root, scope_buffer)?.data;
let plan = self.analyze_scope_internal(plan, root)?.data;
plan.map_subqueries(|plan| {
if let LogicalPlan::Subquery(Subquery { subquery, .. }) = &plan {
if let LogicalPlan::Subquery(Subquery {
subquery,
outer_ref_columns,
}) = &plan
{
outer_ref_columns.iter().try_for_each(|expr| {
let mut scope_mut = root.borrow_mut();
self.collect_required_column(expr.clone(), &mut scope_mut)
})?;
let child_scope =
RefCell::new(Scope::new_child(RefCell::clone(root)));
self.analyze_scope(
Arc::unwrap_or_clone(Arc::clone(subquery)),
&child_scope,
scope_buffer,
)?
.data;
)?;
let mut scope_buffer = scope_buffer.borrow_mut();
scope_buffer.push_back(child_scope);
}
Expand All @@ -73,7 +79,6 @@ impl ModelAnalyzeRule {
&self,
plan: LogicalPlan,
scope: &RefCell<Scope>,
scope_buffer: &RefCell<VecDeque<RefCell<Scope>>>,
) -> Result<Transformed<LogicalPlan>> {
match &plan {
LogicalPlan::TableScan(table_scan) => {
Expand Down Expand Up @@ -155,18 +160,6 @@ impl ModelAnalyzeRule {
})?;
Ok(Transformed::no(plan))
}
LogicalPlan::Subquery(subquery) => {
subquery.outer_ref_columns.iter().try_for_each(|expr| {
let mut scope_mut = scope.borrow_mut();
self.collect_required_column(expr.clone(), &mut scope_mut)
})?;
// create a new scope for the subquery
let child_scope = RefCell::new(Scope::new_child(RefCell::clone(&scope)));
let plan = self.analyze_scope(plan, &child_scope, scope_buffer)?.data;
let mut scope_buffer = scope_buffer.borrow_mut();
scope_buffer.push_back(child_scope);
Ok(Transformed::no(plan))
}
LogicalPlan::SubqueryAlias(subquery_alias) => {
let mut scope_mut = scope.borrow_mut();
if let LogicalPlan::TableScan(table_scan) =
Expand Down Expand Up @@ -211,18 +204,16 @@ impl ModelAnalyzeRule {
&self.analyzed_wren_mdl.wren_mdl(),
relation.clone(),
Arc::clone(&self.session_state),
) {
if self
.analyzed_wren_mdl
.wren_mdl()
.get_view(relation.table())
.is_none()
{
scope.add_required_column(
relation.clone(),
Expr::Column(Column::new(Some(relation), name)),
)?;
}
) && self
.analyzed_wren_mdl
.wren_mdl()
.get_view(relation.table())
.is_none()
{
scope.add_required_column(
relation.clone(),
Expr::Column(Column::new(Some(relation), name)),
)?;
}
}
// It is possible that the column is a rebase column from the aggregation or join
Expand Down Expand Up @@ -353,9 +344,8 @@ impl ModelAnalyzeRule {
scope,
)?
.data;
let subquery = LogicalPlanBuilder::from(model_plan)
.alias(alias)?
.build()?;
let subquery =
LogicalPlanBuilder::from(model_plan).alias(alias)?.build()?;
Ok(Transformed::yes(subquery))
}
_ => Ok(Transformed::no(LogicalPlan::SubqueryAlias(
Expand Down Expand Up @@ -714,10 +704,10 @@ impl ModelAnalyzeRule {

impl AnalyzerRule for ModelAnalyzeRule {
fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
let mut queue = RefCell::new(VecDeque::new());
let mut root = RefCell::new(Scope::new());
self.analyze_scope(plan, &mut root, &mut queue)?
.map_data(|plan| self.analyze_model(plan, &root, &mut queue).data())?
let queue = RefCell::new(VecDeque::new());
let root = RefCell::new(Scope::new());
self.analyze_scope(plan, &root, &queue)?
.map_data(|plan| self.analyze_model(plan, &root, &queue).data())?
.map_data(|plan| {
plan.transform_up_with_subqueries(&|plan| -> Result<
Transformed<LogicalPlan>,
Expand All @@ -734,7 +724,7 @@ impl AnalyzerRule for ModelAnalyzeRule {
"ModelAnalyzeRule"
}
}
#[derive(Clone)]
#[derive(Clone, Debug, Default)]
pub struct Scope {
/// The columns required by the dataset
required_columns: HashMap<TableReference, HashSet<Expr>>,
Expand Down Expand Up @@ -770,7 +760,7 @@ impl Scope {
if self.visited_dataset.contains_key(&table_ref) {
self.required_columns
.entry(table_ref)
.or_insert(HashSet::new())
.or_default()
.insert(expr);
Ok(())
} else if let Some(ref parent) = &self.parent {
Expand All @@ -787,11 +777,7 @@ impl Scope {
}
}

pub fn add_visited_dataset(
&mut self,
table_ref: TableReference,
dataset: Dataset,
) {
pub fn add_visited_dataset(&mut self, table_ref: TableReference, dataset: Dataset) {
self.visited_dataset.insert(table_ref, dataset);
}

Expand Down
62 changes: 59 additions & 3 deletions wren-modeling-rs/core/src/logical_plan/utils.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::{collections::HashMap, sync::Arc};

use datafusion::arrow::datatypes::{
DataType, Field, IntervalUnit, Schema, SchemaBuilder, SchemaRef, TimeUnit,
};
use datafusion::catalog_common::TableReference;
use datafusion::common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion::datasource::DefaultTableSource;
use datafusion::error::Result;
use datafusion::logical_expr::{builder::LogicalTableSource, TableSource};
use datafusion::logical_expr::{builder::LogicalTableSource, Expr, TableSource};
use log::debug;
use petgraph::dot::{Config, Dot};
use petgraph::Graph;
use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};

use crate::mdl::lineage::DatasetLink;
use crate::mdl::utils::quoted;
Expand Down Expand Up @@ -126,6 +127,7 @@ pub fn create_remote_table_source(model: &Model, mdl: &WrenMDL) -> Arc<dyn Table
} else {
column.name.clone()
};
// TODO: find a way for the remote table to provide the data type
// We don't know the data type of the remote table, so we just mock an Int8 type here
Field::new(name, DataType::Int8, column.no_null)
})
Expand Down Expand Up @@ -196,6 +198,60 @@ pub fn belong_to_mdl(
catalog_match && schema_match
}

/// Gathers the columns referenced by `expr` into `accum`.
///
/// The expression is visited with [`TreeNode::apply`]. Each visited node that
/// is an [`Expr::Column`] contributes its column, and each
/// [`Expr::OuterReferenceColumn`] contributes the column it carries. Nodes of
/// any other kind add nothing themselves, and the visitor always asks the
/// traversal to continue.
///
/// # Errors
///
/// Returns whatever error the underlying traversal reports; the visitor
/// closure itself never fails.
pub fn expr_to_columns(
    expr: &Expr,
    accum: &mut HashSet<datafusion::common::Column>,
) -> Result<()> {
    expr.apply(|node| {
        match node {
            // Plain and outer-reference columns are recorded the same way.
            Expr::Column(column) | Expr::OuterReferenceColumn(_, column) => {
                accum.insert(column.clone());
            }
            // Every remaining variant is listed on purpose rather than hidden
            // behind a `_` arm: when a new `Expr` variant is introduced, the
            // compiler forces whoever adds it to decide how it is handled here.
            Expr::Unnest(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Alias(_)
            | Expr::Literal(_)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::Sort { .. }
            | Expr::ScalarFunction(..)
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(_) => {}
        }
        Ok(TreeNodeRecursion::Continue)
    })?;
    Ok(())
}

#[cfg(test)]
mod test {
use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
Expand Down
92 changes: 88 additions & 4 deletions wren-modeling-rs/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ async fn register_ecommerce_mdl(

pub async fn register_tpch_table(ctx: &SessionContext) -> Result<TestContext> {
let path = PathBuf::from(TEST_RESOURCES).join("tpch");
let data = read_dir_recursive(&path).unwrap();
let data = read_dir_recursive(&path)?;

// register parquet file with the execution context
for file in data.iter() {
Expand All @@ -317,8 +317,7 @@ pub async fn register_tpch_table(ctx: &SessionContext) -> Result<TestContext> {
file.to_str().unwrap(),
ParquetReadOptions::default(),
)
.await
.unwrap();
.await?;
}
let (ctx, mdl) = register_tpch_mdl(ctx).await?;

Expand Down Expand Up @@ -446,7 +445,92 @@ async fn register_tpch_mdl(
.build(),
)
.build();
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?);
let mut register_tables = HashMap::new();
register_tables.insert(
"datafusion.public.customer".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("customer")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.orders".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("orders")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.lineitem".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("lineitem")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.part".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("part")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.supplier".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("supplier")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.partsupp".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("partsupp")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.nation".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("nation")
.await?
.unwrap(),
);
register_tables.insert(
"datafusion.public.region".to_string(),
ctx.catalog("datafusion")
.unwrap()
.schema("public")
.unwrap()
.table("region")
.await?
.unwrap(),
);

let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(
manifest,
register_tables,
)?);
// TODO: there are some conflicts with the DataFusion optimization rules.
// let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl)).await?;
Ok((ctx.to_owned(), analyzed_mdl))
Expand Down
7 changes: 7 additions & 0 deletions wren-modeling-rs/sqllogictest/test_files/model.slt
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,10 @@ query IR
select "Id", "Price" from "Order_items" where "Order_id" in (SELECT "Order_id" FROM "Orders" WHERE "Customer_id" = 'f6c39f83de772dd502809cee2fee4c41')
----
105 287.4

# TODO: DataFusion has a case-sensitivity issue with outer reference column names
# Test the query with outer reference column
# query I
# select "Customer_id" from wrenai.public."Orders" where not exists (select 1 from wrenai.public."Order_items" where "Orders"."Order_id" = "Order_items"."Order_id")
# ----
# 1
Loading

0 comments on commit ed135a2

Please sign in to comment.