Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: Sql joins with duplicate cols #3241

6 changes: 6 additions & 0 deletions src/daft-logical-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ impl LogicalPlanBuilder {
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
keep_join_keys: bool,
) -> DaftResult<Self> {
self.join_with_null_safe_equal(
right,
Expand All @@ -453,6 +454,7 @@ impl LogicalPlanBuilder {
join_strategy,
join_suffix,
join_prefix,
keep_join_keys,
)
}

Expand All @@ -467,6 +469,7 @@ impl LogicalPlanBuilder {
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
keep_join_keys: bool,
) -> DaftResult<Self> {
let logical_plan: LogicalPlan = ops::Join::try_new(
self.plan.clone(),
Expand All @@ -478,6 +481,7 @@ impl LogicalPlanBuilder {
join_strategy,
join_suffix,
join_prefix,
keep_join_keys,
)?
.into();
Ok(self.with_new_plan(logical_plan))
Expand All @@ -497,6 +501,7 @@ impl LogicalPlanBuilder {
None,
join_suffix,
join_prefix,
false, // no join keys to keep
)
}

Expand Down Expand Up @@ -937,6 +942,7 @@ impl PyLogicalPlanBuilder {
join_strategy,
join_suffix,
join_prefix,
false, // dataframes do not keep the join keys when joining
)?
.into())
}
Expand Down
2 changes: 2 additions & 0 deletions src/daft-logical-plan/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ mod test {
None,
None,
None,
false,
)?
.filter(col("first_name").eq(lit("hello")))?
.select(vec![col("first_name")])?
Expand Down Expand Up @@ -185,6 +186,7 @@ Project1 --> Limit0
None,
None,
None,
false,
)?
.filter(col("first_name").eq(lit("hello")))?
.select(vec![col("first_name")])?
Expand Down
3 changes: 2 additions & 1 deletion src/daft-logical-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ impl LogicalPlan {
*join_type,
*join_strategy,
None, // The suffix is already eagerly computed in the constructor
None // the prefix is already eagerly computed in the constructor
None, // the prefix is already eagerly computed in the constructor
false // this is already eagerly computed in the constructor
).unwrap()),
_ => panic!("Logical op {} has one input, but got two", self),
},
Expand Down
32 changes: 23 additions & 9 deletions src/daft-logical-plan/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ impl Join {
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
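// If true, a right-side join key whose name also appears on the left side is kept in the
// output and renamed with the join prefix/suffix, instead of being skipped by the rename mapping.
// Example: `select * from a left join b on a.id = b.id`
// With `keep_join_keys = true`, the resulting schema contains both join columns
// (e.g. `id` and `b.id` when the prefix is `b.`).
// SQL joins always keep the join columns, while dataframe joins do not.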
keep_join_keys: bool,
) -> logical_plan::Result<Self> {
let (left_on, _) = resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, _) =
Expand Down Expand Up @@ -136,19 +141,27 @@ impl Join {
let right_rename_mapping: HashMap<_, _> = right_names
.iter()
.filter_map(|name| {
if !names_so_far.contains(name) || common_join_keys.contains(name) {
if !names_so_far.contains(name)
|| (common_join_keys.contains(name) && !keep_join_keys)
{
None
} else {
let mut new_name = name.clone();
while names_so_far.contains(&new_name) {
if let Some(prefix) = join_prefix {
new_name = format!("{}{}", prefix, new_name);
} else if join_suffix.is_none() {
new_name = format!("right.{}", new_name);
}
if let Some(suffix) = join_suffix {
new_name = format!("{}{}", new_name, suffix);
}
new_name = match (join_prefix, join_suffix) {
(Some(prefix), Some(suffix)) => {
format!("{}{}{}", prefix, new_name, suffix)
}
(Some(prefix), None) => {
format!("{}{}", prefix, new_name)
}
(None, Some(suffix)) => {
format!("{}{}", new_name, suffix)
}
(None, None) => {
format!("right.{}", new_name)
}
};
}
names_so_far.insert(new_name.clone());

Expand Down Expand Up @@ -253,6 +266,7 @@ impl Join {
}
_ => {
let unique_id = Uuid::new_v4().to_string();

let renamed_left_expr =
left_expr.alias(format!("{}_{}", left_expr.name(), unique_id));
let renamed_right_expr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ mod tests {
None,
None,
None,
false,
)?
.build();

Expand Down Expand Up @@ -554,6 +555,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))?
.build();
Expand Down Expand Up @@ -588,6 +590,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(expr2.and(expr4))?
.build();
Expand Down Expand Up @@ -622,6 +625,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(expr2.or(expr4))?
.build();
Expand Down Expand Up @@ -682,6 +686,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(col("t2.c").lt(lit(15u32)).or(col("t2.c").eq(lit(688u32))))?
.build();
Expand All @@ -699,6 +704,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(
col("t4.c")
Expand All @@ -724,6 +730,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(col("t4.c").lt(lit(15u32)).or(col("t4.c").eq(lit(688u32))))?
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(pred.clone())?
.build();
Expand All @@ -704,6 +705,7 @@ mod tests {
None,
None,
None,
false,
)?
.build();
assert_optimized_plan_eq(plan, expected)?;
Expand Down Expand Up @@ -747,6 +749,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(pred.clone())?
.build();
Expand All @@ -768,6 +771,7 @@ mod tests {
None,
None,
None,
false,
)?
.build();
assert_optimized_plan_eq(plan, expected)?;
Expand Down Expand Up @@ -824,6 +828,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(pred.clone())?
.build();
Expand Down Expand Up @@ -853,6 +858,7 @@ mod tests {
None,
None,
None,
false,
)?
.build();
assert_optimized_plan_eq(plan, expected)?;
Expand Down Expand Up @@ -892,6 +898,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(pred)?
.build();
Expand Down Expand Up @@ -934,6 +941,7 @@ mod tests {
None,
None,
None,
false,
)?
.filter(pred)?
.build();
Expand Down
1 change: 1 addition & 0 deletions src/daft-physical-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,7 @@ mod tests {
Some(JoinStrategy::Hash),
None,
None,
false,
)?
.build();
logical_to_physical(logical_plan, cfg)
Expand Down
3 changes: 2 additions & 1 deletion src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ mod tests {
JoinType::Inner,
None,
None,
None,
Some("tbl3."),
true,
)?
.select(vec![col("*")])?
.build();
Expand Down
Loading
Loading