Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix incorrect predicate pushdown for predicates referring to right-join key columns #21293

Merged
merged 4 commits into from
Feb 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub type ChunkJoinOptIds = Vec<NullableIdxSize>;
#[cfg(not(feature = "chunked_ids"))]
pub type ChunkJoinIds = Vec<IdxSize>;

use once_cell::sync::Lazy;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;
Expand Down Expand Up @@ -138,8 +137,8 @@ impl JoinArgs {
}

pub fn suffix(&self) -> &PlSmallStr {
static DEFAULT: Lazy<PlSmallStr> = Lazy::new(|| PlSmallStr::from_static("_right"));
self.suffix.as_ref().unwrap_or(&*DEFAULT)
const DEFAULT: &PlSmallStr = &PlSmallStr::from_static("_right");
self.suffix.as_ref().unwrap_or(DEFAULT)
}
}

Expand Down
91 changes: 12 additions & 79 deletions crates/polars-plan/src/plans/optimizer/collapse_joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,84 +10,11 @@ use polars_core::schema::*;
use polars_ops::frame::{IEJoinOptions, InequalityOperator};
use polars_ops::frame::{JoinCoalesce, JoinType, MaintainOrderJoin};
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;

use super::{aexpr_to_leaf_names_iter, AExpr, ExprOrigin, JoinOptions, IR};
use crate::dsl::{JoinTypeOptionsIR, Operator};
use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker};
use crate::plans::{ExprIR, MintermIter, OutputName};

fn remove_suffix<'a>(
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code moved to join_utils below for re-use

exprs: &mut Vec<ExprIR>,
expr_arena: &mut Arena<AExpr>,
schema: &'a SchemaRef,
suffix: &'a str,
) {
let mut remover = RemoveSuffix {
schema: schema.as_ref(),
suffix,
};

for expr in exprs {
// Using AexprNode::rewrite() ensures we do not mutate any nodes in-place. The nodes may be
// used in other locations and mutating them will cause really confusing bugs, such as
// https://github.com/pola-rs/polars/issues/20831.
match AexprNode::new(expr.node()).rewrite(&mut remover, expr_arena) {
Ok(v) => {
expr.set_node(v.node());

if let OutputName::ColumnLhs(colname) = expr.output_name_inner() {
if colname.ends_with(suffix) && !schema.contains(colname.as_str()) {
let name = PlSmallStr::from(&colname[..colname.len() - suffix.len()]);
expr.set_columnlhs(name);
}
}
},
e @ Err(_) => panic!("should not have failed: {:?}", e),
}
}
}

struct RemoveSuffix<'a> {
schema: &'a Schema,
suffix: &'a str,
}

impl RewritingVisitor for RemoveSuffix<'_> {
type Node = AexprNode;
type Arena = Arena<AExpr>;

fn pre_visit(
&mut self,
node: &Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<crate::prelude::visitor::RewriteRecursion> {
let AExpr::Column(colname) = arena.get(node.node()) else {
return Ok(RewriteRecursion::NoMutateAndContinue);
};

if !colname.ends_with(self.suffix) || self.schema.contains(colname.as_str()) {
return Ok(RewriteRecursion::NoMutateAndContinue);
}

Ok(RewriteRecursion::MutateAndContinue)
}

fn mutate(
&mut self,
node: Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<Self::Node> {
let AExpr::Column(colname) = arena.get(node.node()) else {
unreachable!();
};

// Safety: Checked in pre_visit()
Ok(AexprNode::new(arena.add(AExpr::Column(PlSmallStr::from(
&colname[..colname.len() - self.suffix.len()],
)))))
}
}
use crate::plans::optimizer::join_utils::remove_suffix;
use crate::plans::{ExprIR, MintermIter};

fn and_expr(left: Node, right: Node, expr_arena: &mut Arena<AExpr>) -> Node {
expr_arena.add(AExpr::BinaryExpr {
Expand Down Expand Up @@ -195,14 +122,16 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &mut Arena<AEx
left_schema,
right_schema,
suffix.as_str(),
);
)
.unwrap();
let right_origin = ExprOrigin::get_expr_origin(
right,
expr_arena,
left_schema,
right_schema,
suffix.as_str(),
);
)
.unwrap();

use ExprOrigin as EO;

Expand Down Expand Up @@ -282,12 +211,16 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &mut Arena<AEx
let mut can_simplify_join = false;

if !eq_left_on.is_empty() {
remove_suffix(&mut eq_right_on, expr_arena, right_schema, suffix.as_str());
for expr in eq_right_on.iter_mut() {
remove_suffix(expr, expr_arena, right_schema, suffix.as_str());
}
can_simplify_join = true;
} else {
#[cfg(feature = "iejoin")]
if !ie_op.is_empty() {
remove_suffix(&mut ie_right_on, expr_arena, right_schema, suffix.as_str());
for expr in ie_right_on.iter_mut() {
remove_suffix(expr, expr_arena, right_schema, suffix.as_str());
}
can_simplify_join = true;
}
can_simplify_join |= options.args.how.is_cross();
Expand Down
145 changes: 109 additions & 36 deletions crates/polars-plan/src/plans/optimizer/join_utils.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
use polars_core::error::{polars_bail, PolarsResult};
use polars_core::schema::*;
use polars_utils::arena::{Arena, Node};
use polars_utils::pl_str::PlSmallStr;

use super::{aexpr_to_leaf_names_iter, AExpr};
use crate::plans::visitor::{AexprNode, RewriteRecursion, RewritingVisitor, TreeWalker};
use crate::plans::{ExprIR, OutputName};

/// Join origin of an expression
#[derive(Debug, Clone, PartialEq, Copy)]
#[repr(u8)]
pub(crate) enum ExprOrigin {
// Note: There is a merge() function implemented on this enum that relies
// on this exact u8 repr layout.
// Note: BitOr is implemented on this struct that relies on this exact u8
// repr layout (i.e. treated as a bitfield).
//
/// Utilizes no columns
None = 0b00,
Expand All @@ -21,52 +25,121 @@ pub(crate) enum ExprOrigin {
}

impl ExprOrigin {
/// Errors with ColumnNotFound if a column cannot be found on either side.
pub(crate) fn get_expr_origin(
root: Node,
expr_arena: &Arena<AExpr>,
left_schema: &SchemaRef,
right_schema: &SchemaRef,
left_schema: &Schema,
right_schema: &Schema,
suffix: &str,
) -> ExprOrigin {
let mut expr_origin = ExprOrigin::None;

for name in aexpr_to_leaf_names_iter(root, expr_arena) {
let in_left = left_schema.contains(name.as_str());
let in_right = right_schema.contains(name.as_str());
let has_suffix = name.as_str().ends_with(suffix);
let in_right = in_right
| (has_suffix
&& right_schema.contains(&name.as_str()[..name.len() - suffix.len()]));

let name_origin = match (in_left, in_right, has_suffix) {
(true, false, _) | (true, true, false) => ExprOrigin::Left,
(false, true, _) | (true, true, true) => ExprOrigin::Right,
(false, false, _) => {
unreachable!("Invalid filter column should have been filtered before")
},
};

use ExprOrigin as O;
expr_origin = match (expr_origin, name_origin) {
(O::None, other) | (other, O::None) => other,
(O::Left, O::Left) => O::Left,
(O::Right, O::Right) => O::Right,
_ => O::Both,
};
}
) -> PolarsResult<ExprOrigin> {
aexpr_to_leaf_names_iter(root, expr_arena).try_fold(
ExprOrigin::None,
|acc_origin, column_name| {
Ok(acc_origin
| Self::get_column_origin(&column_name, left_schema, right_schema, suffix)?)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified to use BitOr

},
)
}

expr_origin
/// Errors with ColumnNotFound if a column cannot be found on either side.
pub(crate) fn get_column_origin(
column_name: &str,
left_schema: &Schema,
right_schema: &Schema,
suffix: &str,
) -> PolarsResult<ExprOrigin> {
Ok(if left_schema.contains(column_name) {
ExprOrigin::Left
} else if right_schema.contains(column_name)
|| column_name
.strip_suffix(suffix)
.is_some_and(|x| right_schema.contains(x))
{
ExprOrigin::Right
} else {
polars_bail!(ColumnNotFound: "{}", column_name)
})
}
}

/// Logical OR with another [`ExprOrigin`]
fn merge(&mut self, other: Self) {
*self = unsafe { std::mem::transmute::<u8, ExprOrigin>(*self as u8 | other as u8) }
impl std::ops::BitOr for ExprOrigin {
type Output = ExprOrigin;

fn bitor(self, rhs: Self) -> Self::Output {
unsafe { std::mem::transmute::<u8, ExprOrigin>(self as u8 | rhs as u8) }
}
}

impl std::ops::BitOrAssign for ExprOrigin {
fn bitor_assign(&mut self, rhs: Self) {
self.merge(rhs)
*self = *self | rhs;
}
}

pub(super) fn remove_suffix<'a>(
expr: &mut ExprIR,
expr_arena: &mut Arena<AExpr>,
schema_rhs: &'a Schema,
suffix: &'a str,
) {
let schema = schema_rhs;
// Using AexprNode::rewrite() ensures we do not mutate any nodes in-place. The nodes may be
// used in other locations and mutating them will cause really confusing bugs, such as
// https://github.com/pola-rs/polars/issues/20831.
let node = AexprNode::new(expr.node())
.rewrite(&mut RemoveSuffix { schema, suffix }, expr_arena)
.unwrap()
.node();

expr.set_node(node);

if let OutputName::ColumnLhs(colname) = expr.output_name_inner() {
if colname.ends_with(suffix) && !schema.contains(colname.as_str()) {
let name = PlSmallStr::from(&colname[..colname.len() - suffix.len()]);
expr.set_columnlhs(name);
}
}

struct RemoveSuffix<'a> {
schema: &'a Schema,
suffix: &'a str,
}

impl RewritingVisitor for RemoveSuffix<'_> {
type Node = AexprNode;
type Arena = Arena<AExpr>;

fn pre_visit(
&mut self,
node: &Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<crate::prelude::visitor::RewriteRecursion> {
let AExpr::Column(colname) = arena.get(node.node()) else {
return Ok(RewriteRecursion::NoMutateAndContinue);
};

if !colname.ends_with(self.suffix) || self.schema.contains(colname.as_str()) {
return Ok(RewriteRecursion::NoMutateAndContinue);
}

Ok(RewriteRecursion::MutateAndContinue)
}

fn mutate(
&mut self,
node: Self::Node,
arena: &mut Self::Arena,
) -> polars_core::prelude::PolarsResult<Self::Node> {
let AExpr::Column(colname) = arena.get(node.node()) else {
unreachable!();
};

// Safety: Checked in pre_visit()
Ok(AexprNode::new(arena.add(AExpr::Column(PlSmallStr::from(
&colname[..colname.len() - self.suffix.len()],
)))))
}
}
}

Expand Down
Loading
Loading