From 5e9c5205c82b3102335e8709db0eaeee89ced8b5 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 6 Nov 2024 17:56:42 -0600 Subject: [PATCH 1/9] fix/feat: sql left join & except --- src/daft-sql/src/planner.rs | 120 ++++++++++++++++++++++++++++++++++-- tests/sql/test_joins.py | 81 +++++++++++++++++------- 2 files changed, 175 insertions(+), 26 deletions(-) diff --git a/src/daft-sql/src/planner.rs index 239be4845d..42ff30abfa 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -204,15 +204,119 @@ impl SQLPlanner { let selection = match query.body.as_ref() { SetExpr::Select(selection) => selection, SetExpr::Query(_) => unsupported_sql_err!("Subqueries are not supported"), - SetExpr::SetOperation { .. } => { - unsupported_sql_err!("Set operations are not supported") + SetExpr::SetOperation { + op, + set_quantifier, + left, + right, + } => { + use sqlparser::ast::{SetOperator, SetQuantifier}; + fn make_query(set_expr: &SetExpr) -> Query { + Query { + with: None, + body: Box::new(set_expr.clone()), + order_by: None, + limit: None, + limit_by: vec![], + offset: None, + fetch: None, + locks: vec![], + for_clause: None, + settings: None, + format_clause: None, + } + } + match (op, set_quantifier) { + // UNION ALL + (SetOperator::Union, SetQuantifier::All) => { + let left = make_query(left); + let right = make_query(right); + + let left = self.plan_query(&left)?; + let right = self.plan_query(&right)?; + return left.concat(&right).map_err(|e| e.into()); + } + (SetOperator::Union, SetQuantifier::Distinct) => { + unsupported_sql_err!("UNION DISTINCT is not supported.") + } + (SetOperator::Union, SetQuantifier::ByName) => { + unsupported_sql_err!("UNION BY NAME is not supported") + } + (SetOperator::Union, SetQuantifier::AllByName) => { + unsupported_sql_err!("UNION ALL BY NAME is not supported") + } + (SetOperator::Union, SetQuantifier::DistinctByName) => { + unsupported_sql_err!("UNION DISTINCT BY NAME is not supported.") + } + (SetOperator::Union, SetQuantifier::None) => { + let left = make_query(left); + let right = make_query(right); + + let left = self.plan_query(&left)?; + let right = self.plan_query(&right)?; + return left.concat(&right)?.distinct().map_err(|e| e.into()); + } + (SetOperator::Except, SetQuantifier::None) => { + let left = make_query(left); + let right = make_query(right); + let left = self.plan_query(&left)?; + let right = self.plan_query(&right)?; + let left_schema = left.schema(); + let right_schema = right.schema(); + if left_schema != right_schema { + invalid_operation_err!("EXCEPT queries must have the same schema") + } + if left_schema.is_empty() { + invalid_operation_err!("EXCEPT queries must have at least one column") + } + let left_on = left_schema + .names() + .into_iter() + .map(|n| col(n.as_ref())) + .collect::<Vec<_>>(); + + let right_on = right_schema + .names() + .into_iter() + .map(|n| col(n.as_ref()).alias(format!("right.{}", n))) + .collect::<Vec<_>>(); + + let Some(Expr::Alias(_, alias)) = right_on.first().map(|e| e.as_ref()) + else { + unreachable!("we know right_on has at least one element") + }; + + let first_from_right = col(alias.as_ref()); + + let joined = left.join_with_null_safe_equal( + right, + left_on.clone(), + right_on, + None, + JoinType::Left, + None, + None, + None, + )?; + + return joined + .filter(first_from_right.is_null()) + .and_then(|plan| plan.select(left_on)) + .map_err(|e| e.into()); + } + (SetOperator::Except, _) => { + unsupported_sql_err!("EXCEPT is not supported") + } +
(SetOperator::Intersect, _) => { + unsupported_sql_err!("INTERSECT is not supported. Use INNER JOIN instead") + } + } } SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"), SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"), SetExpr::Update(..) => unsupported_sql_err!("UPDATE is not supported"), SetExpr::Table(..) => unsupported_sql_err!("TABLE is not supported"), }; - check_select_features(selection)?; if let Some(with) = &query.with { @@ -579,9 +683,15 @@ impl SQLPlanner { // switch left/right operands if the caller has them in reverse if &left_rel.get_name() == tbl_b || &right_rel.get_name() == tbl_a { - Ok((vec![col(col_b.as_ref())], vec![col(col_a.as_ref())])) + Ok(( + vec![col(col_b.as_ref()).alias(format!("{tbl_b}.{col_b}",))], + vec![col(col_a.as_ref())], + )) } else { - Ok((vec![col(col_a.as_ref())], vec![col(col_b.as_ref())])) + Ok(( + vec![col(col_a.as_ref())], + vec![col(col_b.as_ref()).alias(format!("{tbl_b}.{col_b}",))], + )) } } else { unsupported_sql_err!("collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len()); diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index 48d7001df5..5f7d6c41f3 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -1,5 +1,4 @@ import daft -from daft import col from daft.sql import SQLCatalog @@ -19,11 +18,13 @@ def test_joins_with_alias(): df1 = daft.from_pydict({"idx": [1, 2], "val": [10, 20]}) df2 = daft.from_pydict({"idx": [1, 2], "score": [0.1, 0.2]}) - df_sql = daft.sql("select * from df1 as foo join df2 as bar on (foo.idx=bar.idx) where bar.score>0.1") + catalog = SQLCatalog({"df1": df1, "df2": df2}) + + df_sql = daft.sql("select * from df1 as foo join df2 as bar on foo.idx=bar.idx where bar.score>0.1", catalog) actual = df_sql.collect().to_pydict() - expected = df1.join(df2, on="idx").filter(col("score") > 0.1).collect().to_pydict() + expected = {"idx": [2], "val": [20], "bar.idx": [2], "score": [0.2]} assert actual == expected @@ -47,31 +48,24 @@ def test_joins_with_wildcard_expansion(): df2 = daft.from_pydict({"idx": [3], "score": [0.1]}) df3 = daft.from_pydict({"idx": [1], "score": [0.1], "a": [1], "b": [2], "c": [3]}) + catalog = SQLCatalog({"df1": df1, "df2": df2, "df3": df3}) + df_sql = ( - daft.sql(""" + daft.sql( + """ select df3.* from df1 left join df2 on (df1.idx=df2.idx) left join df3 on (df1.idx=df3.idx) - """) - .collect() - .to_pydict() - ) - - expected = ( - df1.join(df2, on="idx", how="left") - .join(df3, on="idx", how="left") - .select( - "idx", - col("right.score").alias("score"), - col("a"), - col("b"), - col("c"), + """, + catalog, ) .collect() .to_pydict() ) + expected = {"idx": [1, None], "score": [0.1, None], "a": [1, None], "b": [2, None], "c": [3, None]} + assert df_sql == expected # make sure it works with exclusion patterns too @@ -86,9 +80,54 @@ def test_joins_with_wildcard_expansion(): .to_pydict() ) + expected = {"idx": [1, None], "score": [0.1, None]} + + assert df_sql == expected + + +def test_joins_with_duplicate_columns(): + table1 = daft.from_pydict({"id": [1, 2, 3, 4], "value": ["a", "b", "c", "d"]}) + + table2 = daft.from_pydict({"id": [2, 3, 4, 5], "value": ["b", "c", "d", "e"]}) + + catalog = SQLCatalog({"table1": table1, "table2": table2}) + + actual = daft.sql( + """ + SELECT * + FROM table1 t1 + LEFT JOIN table2 t2 on t2.id = t1.id; + """, + catalog, + ).collect() + expected = { - "idx": [1, 2], - "score": [0.1, None], + "id": [1, 2, 
3, 4], + "value": ["a", "b", "c", "d"], + "t2.id": [None, 2, 3, 4], + "t2.value": [None, "b", "c", "d"], } - assert df_sql == expected + assert actual.to_pydict() == expected + + +def test_except(): + table1 = daft.from_pydict({"id": [1, 2, 3, 4], "value": ["a", "b", "c", "d"]}) + table2 = daft.from_pydict({"id": [2, 3, 4, 5], "value": ["b", "c", "d", "e"]}) + + catalog = SQLCatalog({"table1": table1, "table2": table2}) + + actual = ( + daft.sql( + """ + SELECT * from table1 t1 EXCEPT select * from table2 t2 + """, + catalog, + ) + .collect() + .to_pydict() + ) + + expected = {"id": [1], "value": ["a"]} + + assert actual == expected From c93e22c31c66e2891cbaf42381014e1abbe55250 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 6 Nov 2024 18:16:10 -0600 Subject: [PATCH 2/9] fix/feat: sql left join & except --- src/daft-sql/src/planner.rs | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 42ff30abfa..6d0a232755 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -256,6 +256,7 @@ impl SQLPlanner { let right = self.plan_query(&right)?; return left.concat(&right)?.distinct().map_err(|e| e.into()); } + // We use an ANTI join on all columns to implement EXCEPT (SetOperator::Except, SetQuantifier::None) => { let left = make_query(left); let right = make_query(right); @@ -275,34 +276,18 @@ impl SQLPlanner { .map(|n| col(n.as_ref())) .collect::>(); - let right_on = right_schema - .names() - .into_iter() - .map(|n| col(n.as_ref()).alias(format!("right.{}", n))) - .collect::>(); - - let Some(Expr::Alias(_, alias)) = right_on.first().map(|e| e.as_ref()) - else { - unreachable!("we know right_on has at least one element") - }; - - let first_from_right = col(alias.as_ref()); - let joined = left.join_with_null_safe_equal( right, left_on.clone(), - right_on, + left_on, None, - JoinType::Left, + JoinType::Anti, None, None, None, )?; - return joined - .filter(first_from_right.is_null()) - .and_then(|plan| plan.select(left_on)) - .map_err(|e| e.into()); + return Ok(joined); } (SetOperator::Except, _) => { unsupported_sql_err!("EXCEPT is not supported") From 133eb958719ba8f3e56c4e9ddbbb341b97501d4e Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 6 Nov 2024 18:29:10 -0600 Subject: [PATCH 3/9] fix borked tests --- src/daft-sql/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 238077efaf..06135b27a8 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -249,7 +249,7 @@ mod tests { } #[rstest( - null_equals_null => [false, true] + null_equals_null => [false] )] fn test_join( mut planner: SQLPlanner, @@ -266,12 +266,12 @@ mod tests { .join_with_null_safe_equal( tbl_3, vec![col("id")], - vec![col("id")], + vec![col("id").alias("tbl3.id")], Some(vec![null_equals_null]), JoinType::Inner, None, None, - None, + Some("tbl3."), )? .select(vec![col("*")])? 
.build(); From 96cb913f9d72b0d089234f02db1b72bc85c50976 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Wed, 6 Nov 2024 18:29:30 -0600 Subject: [PATCH 4/9] fix borked tests --- src/daft-sql/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 06135b27a8..cdb9b28d8b 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -249,7 +249,7 @@ mod tests { } #[rstest( - null_equals_null => [false] + null_equals_null => [false, true] )] fn test_join( mut planner: SQLPlanner, From 3941bdfe001f25703b076ee14abd58080bba2a80 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Thu, 7 Nov 2024 11:03:41 -0600 Subject: [PATCH 5/9] remove set ops from pr --- src/daft-sql/src/planner.rs | 92 +------------------------------------ tests/sql/test_joins.py | 22 --------- 2 files changed, 2 insertions(+), 112 deletions(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 6d0a232755..f57aaf96af 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -205,97 +205,9 @@ impl SQLPlanner { SetExpr::Select(selection) => selection, SetExpr::Query(_) => unsupported_sql_err!("Subqueries are not supported"), SetExpr::SetOperation { - op, - set_quantifier, - left, - right, + op, set_quantifier, .. } => { - use sqlparser::ast::{SetOperator, SetQuantifier}; - fn make_query(set_expr: &SetExpr) -> Query { - Query { - with: None, - body: Box::new(set_expr.clone()), - order_by: None, - limit: None, - limit_by: vec![], - offset: None, - fetch: None, - locks: vec![], - for_clause: None, - settings: None, - format_clause: None, - } - } - match (op, set_quantifier) { - // UNION ALL - (SetOperator::Union, SetQuantifier::All) => { - let left = make_query(left); - let right = make_query(right); - - let left = self.plan_query(&left)?; - let right = self.plan_query(&right)?; - return left.concat(&right).map_err(|e| e.into()); - } - (SetOperator::Union, SetQuantifier::Distinct) => { - unsupported_sql_err!("UNION DISTINCT is not supported.") - } - (SetOperator::Union, SetQuantifier::ByName) => { - unsupported_sql_err!("UNION BY NAME is not supported") - } - (SetOperator::Union, SetQuantifier::AllByName) => { - unsupported_sql_err!("UNION ALL BY NAME is not supported") - } - (SetOperator::Union, SetQuantifier::DistinctByName) => { - unsupported_sql_err!("UNION DISTINCT BY NAME is not supported.") - } - (SetOperator::Union, SetQuantifier::None) => { - let left = make_query(left); - let right = make_query(right); - - let left = self.plan_query(&left)?; - let right = self.plan_query(&right)?; - return left.concat(&right)?.distinct().map_err(|e| e.into()); - } - // We use an ANTI join on all columns to implement EXCEPT - (SetOperator::Except, SetQuantifier::None) => { - let left = make_query(left); - let right = make_query(right); - let left = self.plan_query(&left)?; - let right = self.plan_query(&right)?; - let left_schema = left.schema(); - let right_schema = right.schema(); - if left_schema != right_schema { - invalid_operation_err!("EXCEPT queries must have the same schema") - } - if left_schema.is_empty() { - invalid_operation_err!("EXCEPT queries must have at least one column") - } - let left_on = left_schema - .names() - .into_iter() - .map(|n| col(n.as_ref())) - .collect::>(); - - let joined = left.join_with_null_safe_equal( - right, - left_on.clone(), - left_on, - None, - JoinType::Anti, - None, - None, - None, - )?; - - return Ok(joined); - } - (SetOperator::Except, _) => { - 
unsupported_sql_err!("EXCEPT is not supported") - } - (SetOperator::Intersect, _) => { - unsupported_sql_err!("INTERSECT is not supported. Use INNER JOIN instead") - } - } + unsupported_sql_err!("{op} {set_quantifier} is not supported.",) } SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"), SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"), diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index 5f7d6c41f3..10872906d4 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -109,25 +109,3 @@ def test_joins_with_duplicate_columns(): } assert actual.to_pydict() == expected - - -def test_except(): - table1 = daft.from_pydict({"id": [1, 2, 3, 4], "value": ["a", "b", "c", "d"]}) - table2 = daft.from_pydict({"id": [2, 3, 4, 5], "value": ["b", "c", "d", "e"]}) - - catalog = SQLCatalog({"table1": table1, "table2": table2}) - - actual = ( - daft.sql( - """ - SELECT * from table1 t1 EXCEPT select * from table2 t2 - """, - catalog, - ) - .collect() - .to_pydict() - ) - - expected = {"id": [1], "value": ["a"]} - - assert actual == expected From 4b4615ca99404151590231dbac31fc63cd3035fe Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Nov 2024 12:04:09 -0600 Subject: [PATCH 6/9] better handling for join keys --- src/daft-logical-plan/src/builder.rs | 1102 +++++++++++++++++ src/daft-plan/src/builder.rs | 6 + src/daft-plan/src/display.rs | 2 + src/daft-plan/src/logical_ops/join.rs | 32 +- .../rules/eliminate_cross_join.rs | 7 + .../rules/push_down_filter.rs | 8 + src/daft-plan/src/logical_plan.rs | 3 +- .../src/physical_planner/translate.rs | 1 + src/daft-sql/src/lib.rs | 1 + src/daft-sql/src/planner.rs | 136 +- 10 files changed, 1188 insertions(+), 110 deletions(-) create mode 100644 src/daft-logical-plan/src/builder.rs diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs new file mode 100644 index 0000000000..4a8d9d2cfc --- /dev/null +++ b/src/daft-logical-plan/src/builder.rs @@ -0,0 +1,1102 @@ +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + sync::Arc, +}; + +use common_daft_config::DaftPlanningConfig; +use common_display::mermaid::MermaidDisplayOptions; +use common_error::{DaftError, DaftResult}; +use common_file_formats::{FileFormat, FileFormatConfig, ParquetSourceConfig}; +use common_io_config::IOConfig; +use daft_core::{ + join::{JoinStrategy, JoinType}, + prelude::TimeUnit, +}; +use daft_dsl::{col, ExprRef}; +use daft_scan::{ + glob::GlobScanOperator, + storage_config::{NativeStorageConfig, StorageConfig}, + PhysicalScanInfo, Pushdowns, ScanOperatorRef, +}; +use daft_schema::{ + field::Field, + schema::{Schema, SchemaRef}, +}; +#[cfg(feature = "python")] +use { + crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, + crate::source_info::InMemoryInfo, + common_daft_config::PyDaftPlanningConfig, + daft_dsl::python::PyExpr, + daft_scan::python::pylib::ScanOperatorHandle, + daft_schema::python::schema::PySchema, + pyo3::prelude::*, +}; + +use crate::{ + logical_plan::LogicalPlan, + ops, + optimization::{Optimizer, OptimizerConfig}, + partitioning::{ + HashRepartitionConfig, IntoPartitionsConfig, RandomShuffleConfig, RepartitionSpec, + }, + sink_info::{OutputFileInfo, SinkInfo}, + source_info::SourceInfo, + LogicalPlanRef, +}; + +/// A logical plan builder, which simplifies constructing logical plans via +/// a fluent interface. E.g., LogicalPlanBuilder::table_scan(..).project(..).filter(..).build(). 
+/// +/// This builder holds the current root (sink) of the logical plan, and the building methods return +/// a brand new builder holding a new plan; i.e., this is an immutable builder. +#[derive(Debug, Clone)] +pub struct LogicalPlanBuilder { + // The current root of the logical plan in this builder. + pub plan: Arc, + config: Option>, +} + +impl LogicalPlanBuilder { + pub fn new(plan: Arc, config: Option>) -> Self { + Self { plan, config } + } +} + +impl From<&Self> for LogicalPlanBuilder { + fn from(builder: &Self) -> Self { + Self { + plan: builder.plan.clone(), + config: builder.config.clone(), + } + } +} + +impl From for LogicalPlanRef { + fn from(value: LogicalPlanBuilder) -> Self { + value.plan + } +} + +impl From<&LogicalPlanBuilder> for LogicalPlanRef { + fn from(value: &LogicalPlanBuilder) -> Self { + value.plan.clone() + } +} + +impl From for LogicalPlanBuilder { + fn from(plan: LogicalPlanRef) -> Self { + Self::new(plan, None) + } +} + +pub trait IntoGlobPath { + fn into_glob_path(self) -> Vec; +} +impl IntoGlobPath for Vec { + fn into_glob_path(self) -> Vec { + self + } +} +impl IntoGlobPath for String { + fn into_glob_path(self) -> Vec { + vec![self] + } +} +impl IntoGlobPath for &str { + fn into_glob_path(self) -> Vec { + vec![self.to_string()] + } +} +impl IntoGlobPath for Vec<&str> { + fn into_glob_path(self) -> Vec { + self.iter().map(|s| (*s).to_string()).collect() + } +} +impl LogicalPlanBuilder { + /// Replace the LogicalPlanBuilder's plan with the provided plan + pub fn with_new_plan>>(&self, plan: LP) -> Self { + Self::new(plan.into(), self.config.clone()) + } + + /// Parametrize the LogicalPlanBuilder with a DaftPlanningConfig + pub fn with_config(&self, config: Arc) -> Self { + Self::new(self.plan.clone(), Some(config)) + } + + #[cfg(feature = "python")] + pub fn in_memory_scan( + partition_key: &str, + cache_entry: PyObject, + schema: Arc, + num_partitions: usize, + size_bytes: usize, + num_rows: usize, + ) -> DaftResult { + let source_info = SourceInfo::InMemory(InMemoryInfo::new( + schema.clone(), + partition_key.into(), + cache_entry, + num_partitions, + size_bytes, + num_rows, + None, // TODO(sammy) thread through clustering spec to Python + )); + let logical_plan: LogicalPlan = ops::Source::new(schema, source_info.into()).into(); + + Ok(Self::new(logical_plan.into(), None)) + } + + #[cfg(feature = "python")] + pub fn delta_scan>( + glob_path: T, + io_config: Option, + multithreaded_io: bool, + ) -> DaftResult { + use daft_scan::storage_config::{NativeStorageConfig, PyStorageConfig, StorageConfig}; + + Python::with_gil(|py| { + let io_config = io_config.unwrap_or_default(); + + let native_storage_config = NativeStorageConfig { + io_config: Some(io_config), + multithreaded_io, + }; + + let py_storage_config: PyStorageConfig = + Arc::new(StorageConfig::Native(Arc::new(native_storage_config))).into(); + + // let py_io_config = PyIOConfig { config: io_config }; + let delta_lake_scan = PyModule::import_bound(py, "daft.delta_lake.delta_lake_scan")?; + let delta_lake_scan_operator = + delta_lake_scan.getattr(pyo3::intern!(py, "DeltaLakeScanOperator"))?; + let delta_lake_operator = delta_lake_scan_operator + .call1((glob_path.as_ref(), py_storage_config))? 
+ .to_object(py); + let scan_operator_handle = + ScanOperatorHandle::from_python_scan_operator(delta_lake_operator, py)?; + Self::table_scan(scan_operator_handle.into(), None) + }) + } + + #[cfg(not(feature = "python"))] + pub fn delta_scan( + glob_path: T, + io_config: Option, + multithreaded_io: bool, + ) -> DaftResult { + panic!("Delta Lake scan requires the 'python' feature to be enabled.") + } + + pub fn table_scan( + scan_operator: ScanOperatorRef, + pushdowns: Option, + ) -> DaftResult { + let schema = scan_operator.0.schema(); + let partitioning_keys = scan_operator.0.partitioning_keys(); + let source_info = SourceInfo::Physical(PhysicalScanInfo::new( + scan_operator.clone(), + schema.clone(), + partitioning_keys.into(), + pushdowns.clone().unwrap_or_default(), + )); + // If file path column is specified, check that it doesn't conflict with any column names in the schema. + if let Some(file_path_column) = &scan_operator.0.file_path_column() { + if schema.names().contains(&(*file_path_column).to_string()) { + return Err(DaftError::ValueError(format!( + "Attempting to make a Schema with a file path column name that already exists: {}", + file_path_column + ))); + } + } + // Add generated fields to the schema. + let schema_with_generated_fields = { + if let Some(generated_fields) = scan_operator.0.generated_fields() { + // We use the non-distinct union here because some scan operators have table schema information that + // already contain partitioned fields. For example,the deltalake scan operator takes the table schema. + Arc::new(schema.non_distinct_union(&generated_fields)) + } else { + schema + } + }; + // If column selection (projection) pushdown is specified, prune unselected columns from the schema. + let output_schema = if let Some(Pushdowns { + columns: Some(columns), + .. + }) = &pushdowns + && columns.len() < schema_with_generated_fields.fields.len() + { + let pruned_upstream_schema = schema_with_generated_fields + .fields + .iter() + .filter(|&(name, _)| columns.contains(name)) + .map(|(_, field)| field.clone()) + .collect::>(); + Arc::new(Schema::new(pruned_upstream_schema)?) 
+ } else { + schema_with_generated_fields + }; + let logical_plan: LogicalPlan = ops::Source::new(output_schema, source_info.into()).into(); + Ok(Self::new(logical_plan.into(), None)) + } + + pub fn parquet_scan(glob_path: T) -> ParquetScanBuilder { + ParquetScanBuilder::new(glob_path) + } + + pub fn select(&self, to_select: Vec) -> DaftResult { + let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn with_columns(&self, columns: Vec) -> DaftResult { + let fields = &self.schema().fields; + let current_col_names = fields + .iter() + .map(|(name, _)| name.as_str()) + .collect::>(); + let new_col_name_and_exprs = columns + .iter() + .map(|e| (e.name(), e.clone())) + .collect::>(); + + let mut exprs = fields + .iter() + .map(|(name, _)| { + new_col_name_and_exprs + .get(name.as_str()) + .cloned() + .unwrap_or_else(|| col(name.clone())) + }) + .collect::>(); + + exprs.extend( + columns + .iter() + .filter(|e| !current_col_names.contains(e.name())) + .cloned(), + ); + + let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), exprs)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn exclude(&self, to_exclude: Vec) -> DaftResult { + let to_exclude = HashSet::<_>::from_iter(to_exclude.iter()); + + let exprs = self + .schema() + .fields + .iter() + .filter_map(|(name, _)| { + if to_exclude.contains(name) { + None + } else { + Some(col(name.clone())) + } + }) + .collect::>(); + + let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), exprs)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn filter(&self, predicate: ExprRef) -> DaftResult { + let logical_plan: LogicalPlan = ops::Filter::try_new(self.plan.clone(), predicate)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn limit(&self, limit: i64, eager: bool) -> DaftResult { + let logical_plan: LogicalPlan = ops::Limit::new(self.plan.clone(), limit, eager).into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn explode(&self, to_explode: Vec) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Explode::try_new(self.plan.clone(), to_explode)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn unpivot( + &self, + ids: Vec, + values: Vec, + variable_name: &str, + value_name: &str, + ) -> DaftResult { + let values = if values.is_empty() { + let ids_set = HashSet::<_>::from_iter(ids.iter()); + + self.schema() + .fields + .iter() + .filter_map(|(name, _)| { + let column = col(name.clone()); + + if ids_set.contains(&column) { + None + } else { + Some(column) + } + }) + .collect() + } else { + values + }; + + let logical_plan: LogicalPlan = + ops::Unpivot::try_new(self.plan.clone(), ids, values, variable_name, value_name)? + .into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn sort(&self, sort_by: Vec, descending: Vec) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Sort::try_new(self.plan.clone(), sort_by, descending)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn hash_repartition( + &self, + num_partitions: Option, + partition_by: Vec, + ) -> DaftResult { + let logical_plan: LogicalPlan = ops::Repartition::try_new( + self.plan.clone(), + RepartitionSpec::Hash(HashRepartitionConfig::new(num_partitions, partition_by)), + )? 
+ .into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn random_shuffle(&self, num_partitions: Option) -> DaftResult { + let logical_plan: LogicalPlan = ops::Repartition::try_new( + self.plan.clone(), + RepartitionSpec::Random(RandomShuffleConfig::new(num_partitions)), + )? + .into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn into_partitions(&self, num_partitions: usize) -> DaftResult { + let logical_plan: LogicalPlan = ops::Repartition::try_new( + self.plan.clone(), + RepartitionSpec::IntoPartitions(IntoPartitionsConfig::new(num_partitions)), + )? + .into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn distinct(&self) -> DaftResult { + let logical_plan: LogicalPlan = ops::Distinct::new(self.plan.clone()).into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn sample( + &self, + fraction: f64, + with_replacement: bool, + seed: Option, + ) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Sample::new(self.plan.clone(), fraction, with_replacement, seed).into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn aggregate( + &self, + agg_exprs: Vec, + groupby_exprs: Vec, + ) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Aggregate::try_new(self.plan.clone(), agg_exprs, groupby_exprs)?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn pivot( + &self, + group_by: Vec, + pivot_column: ExprRef, + value_column: ExprRef, + agg_expr: ExprRef, + names: Vec, + ) -> DaftResult { + let pivot_logical_plan: LogicalPlan = ops::Pivot::try_new( + self.plan.clone(), + group_by, + pivot_column, + value_column, + agg_expr, + names, + )? + .into(); + Ok(self.with_new_plan(pivot_logical_plan)) + } + + #[allow(clippy::too_many_arguments)] + pub fn join>( + &self, + right: Right, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + join_strategy: Option, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + ) -> DaftResult { + self.join_with_null_safe_equal( + right, + left_on, + right_on, + None, + join_type, + join_strategy, + join_suffix, + join_prefix, + ) + } + + #[allow(clippy::too_many_arguments)] + pub fn join_with_null_safe_equal>( + &self, + right: Right, + left_on: Vec, + right_on: Vec, + null_equals_nulls: Option>, + join_type: JoinType, + join_strategy: Option, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + ) -> DaftResult { + let logical_plan: LogicalPlan = ops::Join::try_new( + self.plan.clone(), + right.into(), + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + join_suffix, + join_prefix, + )? 
+ .into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn cross_join>( + &self, + right: Right, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + ) -> DaftResult { + self.join( + right, + vec![], + vec![], + JoinType::Inner, + None, + join_suffix, + join_prefix, + ) + } + + pub fn concat(&self, other: &Self) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Concat::try_new(self.plan.clone(), other.plan.clone())?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult { + let logical_plan: LogicalPlan = + ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn table_write( + &self, + root_dir: &str, + file_format: FileFormat, + partition_cols: Option>, + compression: Option, + io_config: Option, + ) -> DaftResult { + let sink_info = SinkInfo::OutputFileInfo(OutputFileInfo::new( + root_dir.into(), + file_format, + partition_cols, + compression, + io_config, + )); + + let logical_plan: LogicalPlan = + ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + #[cfg(feature = "python")] + #[allow(clippy::too_many_arguments)] + pub fn iceberg_write( + &self, + table_name: String, + table_location: String, + partition_spec_id: i64, + partition_cols: Vec, + iceberg_schema: PyObject, + iceberg_properties: PyObject, + io_config: Option, + catalog_columns: Vec, + ) -> DaftResult { + let sink_info = SinkInfo::CatalogInfo(CatalogInfo { + catalog: crate::sink_info::CatalogType::Iceberg(IcebergCatalogInfo { + table_name, + table_location, + partition_spec_id, + partition_cols, + iceberg_schema, + iceberg_properties, + io_config, + }), + catalog_columns, + }); + + let logical_plan: LogicalPlan = + ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + #[cfg(feature = "python")] + #[allow(clippy::too_many_arguments)] + pub fn delta_write( + &self, + path: String, + columns_name: Vec, + mode: String, + version: i32, + large_dtypes: bool, + partition_cols: Option>, + io_config: Option, + ) -> DaftResult { + use crate::sink_info::DeltaLakeCatalogInfo; + let sink_info = SinkInfo::CatalogInfo(CatalogInfo { + catalog: crate::sink_info::CatalogType::DeltaLake(DeltaLakeCatalogInfo { + path, + mode, + version, + large_dtypes, + partition_cols, + io_config, + }), + catalog_columns: columns_name, + }); + + let logical_plan: LogicalPlan = + ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + #[cfg(feature = "python")] + #[allow(clippy::too_many_arguments)] + pub fn lance_write( + &self, + path: String, + columns_name: Vec, + mode: String, + io_config: Option, + kwargs: PyObject, + ) -> DaftResult { + use crate::sink_info::LanceCatalogInfo; + + let sink_info = SinkInfo::CatalogInfo(CatalogInfo { + catalog: crate::sink_info::CatalogType::Lance(LanceCatalogInfo { + path, + mode, + io_config, + kwargs, + }), + catalog_columns: columns_name, + }); + + let logical_plan: LogicalPlan = + ops::Sink::try_new(self.plan.clone(), sink_info.into())?.into(); + Ok(self.with_new_plan(logical_plan)) + } + + pub fn build(&self) -> Arc { + self.plan.clone() + } + + pub fn schema(&self) -> SchemaRef { + self.plan.schema() + } + + pub fn repr_ascii(&self, simple: bool) -> String { + self.plan.repr_ascii(simple) + } + + pub fn repr_mermaid(&self, opts: MermaidDisplayOptions) -> String { + use 
common_display::mermaid::MermaidDisplay; + self.plan.repr_mermaid(opts) + } +} + +pub struct ParquetScanBuilder { + pub glob_paths: Vec, + pub infer_schema: bool, + pub coerce_int96_timestamp_unit: TimeUnit, + pub field_id_mapping: Option>>, + pub row_groups: Option>>>, + pub chunk_size: Option, + pub io_config: Option, + pub multithreaded: bool, + pub schema: Option, + pub file_path_column: Option, + pub hive_partitioning: bool, +} + +impl ParquetScanBuilder { + pub fn new(glob_paths: T) -> Self { + let glob_paths = glob_paths.into_glob_path(); + Self::new_impl(glob_paths) + } + + // concrete implementation to reduce LLVM code duplication + fn new_impl(glob_paths: Vec) -> Self { + Self { + glob_paths, + infer_schema: true, + coerce_int96_timestamp_unit: TimeUnit::Nanoseconds, + field_id_mapping: None, + row_groups: None, + chunk_size: None, + multithreaded: true, + schema: None, + io_config: None, + file_path_column: None, + hive_partitioning: false, + } + } + pub fn infer_schema(mut self, infer_schema: bool) -> Self { + self.infer_schema = infer_schema; + self + } + pub fn coerce_int96_timestamp_unit(mut self, unit: TimeUnit) -> Self { + self.coerce_int96_timestamp_unit = unit; + self + } + pub fn field_id_mapping(mut self, field_id_mapping: Arc>) -> Self { + self.field_id_mapping = Some(field_id_mapping); + self + } + pub fn row_groups(mut self, row_groups: Vec>>) -> Self { + self.row_groups = Some(row_groups); + self + } + pub fn chunk_size(mut self, chunk_size: usize) -> Self { + self.chunk_size = Some(chunk_size); + self + } + + pub fn io_config(mut self, io_config: IOConfig) -> Self { + self.io_config = Some(io_config); + self + } + + pub fn multithreaded(mut self, multithreaded: bool) -> Self { + self.multithreaded = multithreaded; + self + } + + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn file_path_column(mut self, file_path_column: String) -> Self { + self.file_path_column = Some(file_path_column); + self + } + + pub fn hive_partitioning(mut self, hive_partitioning: bool) -> Self { + self.hive_partitioning = hive_partitioning; + self + } + + pub fn finish(self) -> DaftResult { + let cfg = ParquetSourceConfig { + coerce_int96_timestamp_unit: self.coerce_int96_timestamp_unit, + field_id_mapping: self.field_id_mapping, + row_groups: self.row_groups, + chunk_size: self.chunk_size, + }; + + let operator = Arc::new(GlobScanOperator::try_new( + self.glob_paths, + Arc::new(FileFormatConfig::Parquet(cfg)), + Arc::new(StorageConfig::Native(Arc::new( + NativeStorageConfig::new_internal(self.multithreaded, self.io_config), + ))), + self.infer_schema, + self.schema, + self.file_path_column, + self.hive_partitioning, + )?); + + LogicalPlanBuilder::table_scan(ScanOperatorRef(operator), None) + } +} + +/// A Python-facing wrapper of the LogicalPlanBuilder. +/// +/// This lightweight proxy interface should hold as much of the Python-specific logic +/// as possible, converting pyo3 wrapper type arguments into their underlying Rust-native types +/// (e.g. PySchema -> Schema). +#[cfg_attr(feature = "python", pyclass(name = "LogicalPlanBuilder"))] +#[derive(Debug)] +pub struct PyLogicalPlanBuilder { + // Internal logical plan builder. 
+ pub builder: LogicalPlanBuilder, +} + +impl PyLogicalPlanBuilder { + pub fn new(builder: LogicalPlanBuilder) -> Self { + Self { builder } + } +} + +#[cfg(feature = "python")] +fn pyexprs_to_exprs(vec: Vec) -> Vec { + vec.into_iter().map(|e| e.into()).collect() +} + +#[cfg(feature = "python")] +#[pymethods] +impl PyLogicalPlanBuilder { + #[staticmethod] + pub fn in_memory_scan( + partition_key: &str, + cache_entry: PyObject, + schema: PySchema, + num_partitions: usize, + size_bytes: usize, + num_rows: usize, + ) -> PyResult { + Ok(LogicalPlanBuilder::in_memory_scan( + partition_key, + cache_entry, + schema.into(), + num_partitions, + size_bytes, + num_rows, + )? + .into()) + } + + #[staticmethod] + pub fn table_scan(scan_operator: ScanOperatorHandle) -> PyResult { + Ok(LogicalPlanBuilder::table_scan(scan_operator.into(), None)?.into()) + } + + pub fn with_planning_config( + &self, + daft_planning_config: PyDaftPlanningConfig, + ) -> PyResult { + Ok(self.builder.with_config(daft_planning_config.config).into()) + } + + pub fn select(&self, to_select: Vec) -> PyResult { + Ok(self.builder.select(pyexprs_to_exprs(to_select))?.into()) + } + + pub fn with_columns(&self, columns: Vec) -> PyResult { + Ok(self.builder.with_columns(pyexprs_to_exprs(columns))?.into()) + } + + pub fn exclude(&self, to_exclude: Vec) -> PyResult { + Ok(self.builder.exclude(to_exclude)?.into()) + } + + pub fn filter(&self, predicate: PyExpr) -> PyResult { + Ok(self.builder.filter(predicate.expr)?.into()) + } + + pub fn limit(&self, limit: i64, eager: bool) -> PyResult { + Ok(self.builder.limit(limit, eager)?.into()) + } + + pub fn explode(&self, to_explode: Vec) -> PyResult { + Ok(self.builder.explode(pyexprs_to_exprs(to_explode))?.into()) + } + + pub fn unpivot( + &self, + ids: Vec, + values: Vec, + variable_name: &str, + value_name: &str, + ) -> PyResult { + let ids_exprs = ids + .iter() + .map(|e| e.clone().into()) + .collect::>(); + let values_exprs = values + .iter() + .map(|e| e.clone().into()) + .collect::>(); + Ok(self + .builder + .unpivot(ids_exprs, values_exprs, variable_name, value_name)? + .into()) + } + + pub fn sort(&self, sort_by: Vec, descending: Vec) -> PyResult { + Ok(self + .builder + .sort(pyexprs_to_exprs(sort_by), descending)? + .into()) + } + + pub fn hash_repartition( + &self, + partition_by: Vec, + num_partitions: Option, + ) -> PyResult { + Ok(self + .builder + .hash_repartition(num_partitions, pyexprs_to_exprs(partition_by))? + .into()) + } + + pub fn random_shuffle(&self, num_partitions: Option) -> PyResult { + Ok(self.builder.random_shuffle(num_partitions)?.into()) + } + + pub fn into_partitions(&self, num_partitions: usize) -> PyResult { + Ok(self.builder.into_partitions(num_partitions)?.into()) + } + + pub fn distinct(&self) -> PyResult { + Ok(self.builder.distinct()?.into()) + } + + pub fn sample( + &self, + fraction: f64, + with_replacement: bool, + seed: Option, + ) -> PyResult { + Ok(self + .builder + .sample(fraction, with_replacement, seed)? + .into()) + } + + pub fn aggregate(&self, agg_exprs: Vec, groupby_exprs: Vec) -> PyResult { + Ok(self + .builder + .aggregate(pyexprs_to_exprs(agg_exprs), pyexprs_to_exprs(groupby_exprs))? + .into()) + } + + pub fn pivot( + &self, + group_by: Vec, + pivot_column: PyExpr, + value_column: PyExpr, + agg_expr: PyExpr, + names: Vec, + ) -> PyResult { + Ok(self + .builder + .pivot( + pyexprs_to_exprs(group_by), + pivot_column.into(), + value_column.into(), + agg_expr.into(), + names, + )? 
+ .into()) + } + #[allow(clippy::too_many_arguments)] + pub fn join( + &self, + right: &Self, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + join_strategy: Option, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + ) -> PyResult { + Ok(self + .builder + .join( + &right.builder, + pyexprs_to_exprs(left_on), + pyexprs_to_exprs(right_on), + join_type, + join_strategy, + join_suffix, + join_prefix, + )? + .into()) + } + + pub fn concat(&self, other: &Self) -> DaftResult { + Ok(self.builder.concat(&other.builder)?.into()) + } + + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult { + Ok(self + .builder + .add_monotonically_increasing_id(column_name)? + .into()) + } + + pub fn table_write( + &self, + root_dir: &str, + file_format: FileFormat, + partition_cols: Option>, + compression: Option, + io_config: Option, + ) -> PyResult { + Ok(self + .builder + .table_write( + root_dir, + file_format, + partition_cols.map(pyexprs_to_exprs), + compression, + io_config.map(|cfg| cfg.config), + )? + .into()) + } + + #[allow(clippy::too_many_arguments)] + pub fn iceberg_write( + &self, + table_name: String, + table_location: String, + partition_spec_id: i64, + partition_cols: Vec, + iceberg_schema: PyObject, + iceberg_properties: PyObject, + catalog_columns: Vec, + io_config: Option, + ) -> PyResult { + Ok(self + .builder + .iceberg_write( + table_name, + table_location, + partition_spec_id, + pyexprs_to_exprs(partition_cols), + iceberg_schema, + iceberg_properties, + io_config.map(|cfg| cfg.config), + catalog_columns, + )? + .into()) + } + + #[allow(clippy::too_many_arguments)] + pub fn delta_write( + &self, + path: String, + columns_name: Vec, + mode: String, + version: i32, + large_dtypes: bool, + partition_cols: Option>, + io_config: Option, + ) -> PyResult { + Ok(self + .builder + .delta_write( + path, + columns_name, + mode, + version, + large_dtypes, + partition_cols, + io_config.map(|cfg| cfg.config), + )? + .into()) + } + + pub fn lance_write( + &self, + py: Python, + path: String, + columns_name: Vec, + mode: String, + io_config: Option, + kwargs: Option, + ) -> PyResult { + let kwargs = kwargs.unwrap_or_else(|| py.None()); + Ok(self + .builder + .lance_write( + path, + columns_name, + mode, + io_config.map(|cfg| cfg.config), + kwargs, + )? + .into()) + } + pub fn schema(&self) -> PyResult { + Ok(self.builder.schema().into()) + } + + /// Optimize the underlying logical plan, returning a new plan builder containing the optimized plan. 
+ pub fn optimize(&self, py: Python) -> PyResult { + py.allow_threads(|| { + // Create optimizer + let default_optimizer_config: OptimizerConfig = Default::default(); + let optimizer_config = OptimizerConfig { enable_actor_pool_projections: self.builder.config.as_ref().map(|planning_cfg| planning_cfg.enable_actor_pool_projections).unwrap_or(default_optimizer_config.enable_actor_pool_projections), ..default_optimizer_config }; + let optimizer = Optimizer::new(optimizer_config); + + // Run LogicalPlan optimizations + let unoptimized_plan = self.builder.build(); + let optimized_plan = optimizer.optimize( + unoptimized_plan, + |new_plan, rule_batch, pass, transformed, seen| { + if transformed { + log::debug!( + "Rule batch {:?} transformed plan on pass {}, and produced {} plan:\n{}", + rule_batch, + pass, + if seen { "an already seen" } else { "a new" }, + new_plan.repr_ascii(true), + ); + } else { + log::debug!( + "Rule batch {:?} did NOT transform plan on pass {} for plan:\n{}", + rule_batch, + pass, + new_plan.repr_ascii(true), + ); + } + }, + )?; + + let builder = LogicalPlanBuilder::new(optimized_plan, self.builder.config.clone()); + Ok(builder.into()) + }) + } + + pub fn repr_ascii(&self, simple: bool) -> PyResult { + Ok(self.builder.repr_ascii(simple)) + } + + pub fn repr_mermaid(&self, opts: MermaidDisplayOptions) -> String { + self.builder.repr_mermaid(opts) + } +} + +impl From for PyLogicalPlanBuilder { + fn from(plan: LogicalPlanBuilder) -> Self { + Self::new(plan) + } +} diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 34cdf65104..9086cce5f4 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -454,6 +454,7 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + keep_join_keys: bool, ) -> DaftResult { self.join_with_null_safe_equal( right, @@ -464,6 +465,7 @@ impl LogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + keep_join_keys, ) } @@ -478,6 +480,7 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + keep_join_keys: bool, ) -> DaftResult { let logical_plan: LogicalPlan = logical_ops::Join::try_new( self.plan.clone(), @@ -489,6 +492,7 @@ impl LogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + keep_join_keys, )? .into(); Ok(self.with_new_plan(logical_plan)) @@ -508,6 +512,7 @@ impl LogicalPlanBuilder { None, join_suffix, join_prefix, + false, // no join keys to keep ) } @@ -948,6 +953,7 @@ impl PyLogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + false, )? .into()) } diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs index 0f3228cc03..c860b23412 100644 --- a/src/daft-plan/src/display.rs +++ b/src/daft-plan/src/display.rs @@ -162,6 +162,7 @@ mod test { None, None, None, + false, )? .filter(col("first_name").eq(lit("hello")))? .select(vec![col("first_name")])? @@ -236,6 +237,7 @@ Project1 --> Limit0 None, None, None, + false, )? .filter(col("first_name").eq(lit("hello")))? .select(vec![col("first_name")])? 
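The logical_ops/join.rs hunk that follows restructures the renaming of clashing right-side column names into a single match over the (join_prefix, join_suffix) pair and threads through the new keep_join_keys flag, so that common join keys on the right side survive in the output schema (SQL keeps both columns of SELECT *, while the DataFrame API drops the duplicate). An illustrative standalone sketch of that rename rule, not the actual daft-plan code (the function name here is invented for the example):

fn rename_clashing(name: &str, prefix: Option<&str>, suffix: Option<&str>) -> String {
    // Mirrors the match added in join.rs below; the real code also loops
    // until the generated name is unique among the columns seen so far.
    match (prefix, suffix) {
        (Some(p), Some(s)) => format!("{p}{name}{s}"),
        (Some(p), None) => format!("{p}{name}"),
        (None, Some(s)) => format!("{name}{s}"),
        (None, None) => format!("right.{name}"),
    }
}

fn main() {
    // The SQL planner passes a join_prefix of "t2." for a table aliased as t2,
    // so a clashing right-side `id` surfaces as `t2.id`, which is what the new
    // test_joins_with_duplicate_columns test expects.
    assert_eq!(rename_clashing("id", Some("t2."), None), "t2.id");
    assert_eq!(rename_clashing("id", None, None), "right.id");
}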
diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index b3310657d1..16384fd4db 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -60,6 +60,11 @@ impl Join { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + // if true, then duplicate column names will be kept + // ex: select * from a left join b on a.id = b.id + // if true, then the resulting schema will have two columns named id (id, and b.id) + // In SQL the join column is always kept, while in dataframes it is not + keep_join_keys: bool, ) -> logical_plan::Result { let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?; let (right_on, _) = @@ -136,19 +141,27 @@ impl Join { let right_rename_mapping: HashMap<_, _> = right_names .iter() .filter_map(|name| { - if !names_so_far.contains(name) || common_join_keys.contains(name) { + if !names_so_far.contains(name) + || (common_join_keys.contains(name) && !keep_join_keys) + { None } else { let mut new_name = name.clone(); while names_so_far.contains(&new_name) { - if let Some(prefix) = join_prefix { - new_name = format!("{}{}", prefix, new_name); - } else if join_suffix.is_none() { - new_name = format!("right.{}", new_name); - } - if let Some(suffix) = join_suffix { - new_name = format!("{}{}", new_name, suffix); - } + new_name = match (join_prefix, join_suffix) { + (Some(prefix), Some(suffix)) => { + format!("{}{}{}", prefix, new_name, suffix) + } + (Some(prefix), None) => { + format!("{}{}", prefix, new_name) + } + (None, Some(suffix)) => { + format!("{}{}", new_name, suffix) + } + (None, None) => { + format!("right.{}", new_name) + } + }; } names_so_far.insert(new_name.clone()); @@ -253,6 +266,7 @@ impl Join { } _ => { let unique_id = Uuid::new_v4().to_string(); + let renamed_left_expr = left_expr.alias(format!("{}_{}", left_expr.name(), unique_id)); let renamed_right_expr = diff --git a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs index 2bc6bea766..492a6720f2 100644 --- a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs +++ b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs @@ -524,6 +524,7 @@ mod tests { None, None, None, + false, )? .build(); @@ -554,6 +555,7 @@ mod tests { None, None, None, + false, )? .filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))? .build(); @@ -588,6 +590,7 @@ mod tests { None, None, None, + false, )? .filter(expr2.and(expr4))? .build(); @@ -622,6 +625,7 @@ mod tests { None, None, None, + false, )? .filter(expr2.or(expr4))? .build(); @@ -682,6 +686,7 @@ mod tests { None, None, None, + false, )? .filter(col("t2.c").lt(lit(15u32)).or(col("t2.c").eq(lit(688u32))))? .build(); @@ -699,6 +704,7 @@ mod tests { None, None, None, + false, )? .filter( col("t4.c") @@ -724,6 +730,7 @@ mod tests { None, None, None, + false, )? .filter(col("t4.c").lt(lit(15u32)).or(col("t4.c").eq(lit(688u32))))? .build(); diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs index 0000606be8..6f07d216ce 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs @@ -683,6 +683,7 @@ mod tests { None, None, None, + false, )? .filter(pred.clone())? .build(); @@ -704,6 +705,7 @@ mod tests { None, None, None, + false, )? 
.build(); assert_optimized_plan_eq(plan, expected)?; @@ -747,6 +749,7 @@ mod tests { None, None, None, + false, )? .filter(pred.clone())? .build(); @@ -768,6 +771,7 @@ mod tests { None, None, None, + false, )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -824,6 +828,7 @@ mod tests { None, None, None, + false, )? .filter(pred.clone())? .build(); @@ -853,6 +858,7 @@ mod tests { None, None, None, + false, )? .build(); assert_optimized_plan_eq(plan, expected)?; @@ -892,6 +898,7 @@ mod tests { None, None, None, + false, )? .filter(pred)? .build(); @@ -934,6 +941,7 @@ mod tests { None, None, None, + false, )? .filter(pred)? .build(); diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs index 37b75217c8..287adce68f 100644 --- a/src/daft-plan/src/logical_plan.rs +++ b/src/daft-plan/src/logical_plan.rs @@ -273,7 +273,8 @@ impl LogicalPlan { *join_type, *join_strategy, None, // The suffix is already eagerly computed in the constructor - None // the prefix is already eagerly computed in the constructor + None, // the prefix is already eagerly computed in the constructor + false // this is already eagerly computed in the constructor ).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 2bfc4a2aed..6ae0225659 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -1212,6 +1212,7 @@ mod tests { Some(JoinStrategy::Hash), None, None, + false, )? .build(); logical_to_physical(logical_plan, cfg) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index cdb9b28d8b..dd1accbc8c 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -272,6 +272,7 @@ mod tests { None, None, Some("tbl3."), + false, )? .select(vec![col("*")])? 
.build(); diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index f57aaf96af..6f608842e3 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -553,15 +553,9 @@ impl SQLPlanner { self.table_map.insert(right.get_name(), right.clone()); let right_join_prefix = Some(format!("{}.", right.get_name())); - rel.inner = rel.inner.join( - right.inner, - vec![], - vec![], - JoinType::Inner, - None, - None, - right_join_prefix.as_deref(), - )?; + rel.inner = + rel.inner + .cross_join(right.inner, None, right_join_prefix.as_deref())?; } return Ok(rel); } @@ -580,15 +574,9 @@ impl SQLPlanner { // switch left/right operands if the caller has them in reverse if &left_rel.get_name() == tbl_b || &right_rel.get_name() == tbl_a { - Ok(( - vec![col(col_b.as_ref()).alias(format!("{tbl_b}.{col_b}",))], - vec![col(col_a.as_ref())], - )) + Ok((vec![col(col_b.as_ref())], vec![col(col_a.as_ref())])) } else { - Ok(( - vec![col(col_a.as_ref())], - vec![col(col_b.as_ref()).alias(format!("{tbl_b}.{col_b}",))], - )) + Ok((vec![col(col_a.as_ref())], vec![col(col_b.as_ref())])) } } else { unsupported_sql_err!("collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", left.len(), right.len()); @@ -643,10 +631,7 @@ impl SQLPlanner { for join in &from.joins { use sqlparser::ast::{ JoinConstraint, - JoinOperator::{ - AsOf, CrossApply, CrossJoin, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, - OuterApply, RightAnti, RightOuter, RightSemi, - }, + JoinOperator::{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}, }; let right_rel = self.plan_relation(&join.relation)?; self.table_map @@ -654,94 +639,45 @@ impl SQLPlanner { let right_rel_name = right_rel.get_name(); let right_join_prefix = Some(format!("{right_rel_name}.")); - match &join.join_operator { - Inner(JoinConstraint::On(expr)) => { + let (join_type, constraint) = match &join.join_operator { + Inner(constraint) => (JoinType::Inner, constraint), + LeftOuter(constraint) => (JoinType::Left, constraint), + RightOuter(constraint) => (JoinType::Right, constraint), + FullOuter(constraint) => (JoinType::Outer, constraint), + LeftSemi(constraint) => (JoinType::Semi, constraint), + LeftAnti(constraint) => (JoinType::Anti, constraint), + + _ => unsupported_sql_err!("Unsupported join type: {:?}", join.join_operator), + }; + + let (left_on, right_on, null_eq_null, keep_join_keys) = match &constraint { + JoinConstraint::On(expr) => { let (left_on, right_on, null_equals_nulls) = process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Inner, - None, - None, - right_join_prefix.as_deref(), - )?; + (left_on, right_on, Some(null_equals_nulls), true) } - Inner(JoinConstraint::Using(idents)) => { + JoinConstraint::Using(idents) => { let on = idents .iter() .map(|i| col(i.value.clone())) .collect::>(); - - left_rel.inner = left_rel.inner.join( - right_rel.inner, - on.clone(), - on, - JoinType::Inner, - None, - None, - right_join_prefix.as_deref(), - )?; - } - LeftOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Left, - None, - None, - right_join_prefix.as_deref(), - )?; - } - RightOuter(JoinConstraint::On(expr)) 
=> { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Right, - None, - None, - right_join_prefix.as_deref(), - )?; - } - - FullOuter(JoinConstraint::On(expr)) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, - left_on, - right_on, - Some(null_equals_nulls), - JoinType::Outer, - None, - None, - right_join_prefix.as_deref(), - )?; + (on.clone(), on, None, false) } - CrossJoin => unsupported_sql_err!("CROSS JOIN"), - LeftSemi(_) => unsupported_sql_err!("LEFT SEMI JOIN"), - RightSemi(_) => unsupported_sql_err!("RIGHT SEMI JOIN"), - LeftAnti(_) => unsupported_sql_err!("LEFT ANTI JOIN"), - RightAnti(_) => unsupported_sql_err!("RIGHT ANTI JOIN"), - CrossApply => unsupported_sql_err!("CROSS APPLY"), - OuterApply => unsupported_sql_err!("OUTER APPLY"), - AsOf { .. } => unsupported_sql_err!("AS OF"), - join_type => unsupported_sql_err!("join type: {join_type:?}"), + JoinConstraint::Natural => unsupported_sql_err!("NATURAL JOIN not supported"), + JoinConstraint::None => unsupported_sql_err!("JOIN without ON/USING not supported"), }; + + left_rel.inner = left_rel.inner.join_with_null_safe_equal( + right_rel.inner, + left_on, + right_on, + null_eq_null, + join_type, + None, + None, + right_join_prefix.as_deref(), + keep_join_keys, + )?; } Ok(left_rel) From 3ad22f3c77b1ea5f9e55bc471123a41f4ea4e743 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Nov 2024 12:11:36 -0600 Subject: [PATCH 7/9] fix bad merge --- src/daft-logical-plan/src/builder.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index 4a8d9d2cfc..b3600292e9 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -443,6 +443,7 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + keep_join_keys: bool, ) -> DaftResult { self.join_with_null_safe_equal( right, @@ -453,6 +454,7 @@ impl LogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + keep_join_keys, ) } @@ -467,6 +469,7 @@ impl LogicalPlanBuilder { join_strategy: Option, join_suffix: Option<&str>, join_prefix: Option<&str>, + keep_join_keys: bool, ) -> DaftResult { let logical_plan: LogicalPlan = ops::Join::try_new( self.plan.clone(), @@ -478,6 +481,7 @@ impl LogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + keep_join_keys, )? .into(); Ok(self.with_new_plan(logical_plan)) @@ -497,6 +501,7 @@ impl LogicalPlanBuilder { None, join_suffix, join_prefix, + false, // no join keys to keep ) } @@ -937,6 +942,7 @@ impl PyLogicalPlanBuilder { join_strategy, join_suffix, join_prefix, + false, // dataframes do not keep the join keys when joining )? 
.into()) } From d88edec299fec4819f568828dcda1538d97d4afc Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Nov 2024 12:57:44 -0600 Subject: [PATCH 8/9] fix more join bugs --- src/daft-sql/src/planner.rs | 29 +++++++++++++++++++++++++++-- tests/sql/test_joins.py | 16 ++++++++++++++++ 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 7fbd68454a..4613f7e139 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -653,8 +653,33 @@ impl SQLPlanner { let null_equals_null = *op == BinaryOperator::Spaceship; collect_compound_identifiers(left, right, left_rel, right_rel) .map(|(left, right)| (left, right, vec![null_equals_null])) + } else if let ( + sqlparser::ast::Expr::Identifier(left), + sqlparser::ast::Expr::Identifier(right), + ) = (left.as_ref(), right.as_ref()) + { + let left = ident_to_str(left); + let right = ident_to_str(right); + + // we don't know which table the identifiers belong to, so we need to check both + let left_schema = left_rel.schema(); + let right_schema = right_rel.schema(); + + // if the left side is in the left schema, then we assume the right side is in the right schema + let (left_on, right_on) = if left_schema.get_field(&left).is_ok() { + (col(left), col(right)) + // if the right side is in the left schema, then we assume the left side is in the right schema + } else if right_schema.get_field(&left).is_ok() { + (col(right), col(left)) + } else { + unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left); + }; + + let null_equals_null = *op == BinaryOperator::Spaceship; + + Ok((vec![left_on], vec![right_on], vec![null_equals_null])) } else { - unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found lhs={:?}, rhs={:?}", left, right); + unsupported_sql_err!("JOIN clauses support '='/'<=>' constraints on identifiers; found `{left} {op} {right}`"); } } BinaryOperator::And => { @@ -668,7 +693,7 @@ impl SQLPlanner { Ok((left_i, right_i, null_equals_nulls_i)) } _ => { - unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{:?}'", op); + unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{}'", op); } } } else if let sqlparser::ast::Expr::Nested(expr) = expression { diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index 10872906d4..3914d43f00 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -1,3 +1,5 @@ +import pytest + import daft from daft.sql import SQLCatalog @@ -109,3 +111,17 @@ def test_joins_with_duplicate_columns(): } assert actual.to_pydict() == expected + + +@pytest.mark.parametrize("join_condition", ["idx=idax", "idax=idx"]) +def test_joins_without_compound_ident(join_condition): + df1 = daft.from_pydict({"idx": [1, None], "val": [10, 20]}) + df2 = daft.from_pydict({"idax": [1, None], "score": [0.1, 0.2]}) + + catalog = SQLCatalog({"df1": df1, "df2": df2}) + + df_sql = daft.sql(f"select * from df1 join df2 on {join_condition}", catalog).to_pydict() + + expected = {"idx": [1], "val": [10], "idax": [1], "score": [0.1]} + + assert df_sql == expected From ebf964458b42a743ad770ff42fd05d82ff87ce09 Mon Sep 17 00:00:00 2001 From: universalmind303 Date: Fri, 8 Nov 2024 12:59:37 -0600 Subject: [PATCH 9/9] fix borked tests --- src/daft-sql/src/lib.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 
6ccd8d5aad..fcd348f02c 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -266,13 +266,13 @@ mod tests { .join_with_null_safe_equal( tbl_3, vec![col("id")], - vec![col("id").alias("tbl3.id")], + vec![col("id")], Some(vec![null_equals_null]), JoinType::Inner, None, None, Some("tbl3."), - false, + true, )? .select(vec![col("*")])? .build();
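Patch 8 above lets an ON clause use bare column names (e.g. `on idx = idax`) by probing which relation's schema owns the left identifier and swapping the operands when it belongs to the right-hand relation. A self-contained sketch of that resolution rule, not the planner's actual code (the function name and column lists are invented for the example):

fn resolve_sides<'a>(
    left_ident: &'a str,
    right_ident: &'a str,
    left_cols: &[&str],
    right_cols: &[&str],
) -> Result<(&'a str, &'a str), String> {
    // Keep the order if the left identifier is a column of the left relation;
    // swap if it is a column of the right relation; otherwise raise the same
    // error message the planner reports.
    if left_cols.contains(&left_ident) {
        Ok((left_ident, right_ident))
    } else if right_cols.contains(&left_ident) {
        Ok((right_ident, left_ident))
    } else {
        Err(format!(
            "JOIN clauses must reference columns in the joined tables; found `{left_ident}`"
        ))
    }
}

fn main() {
    // Mirrors the parametrized test_joins_without_compound_ident cases: both
    // `idx=idax` and `idax=idx` resolve to (left_on = idx, right_on = idax).
    let left = ["idx", "val"];
    let right = ["idax", "score"];
    assert_eq!(resolve_sides("idx", "idax", &left, &right).unwrap(), ("idx", "idax"));
    assert_eq!(resolve_sides("idax", "idx", &left, &right).unwrap(), ("idx", "idax"));
}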