2323//! pipeline-friendly ones. To achieve the second goal, it selects the proper
2424//! `PartitionMode` and the build side using the available statistics for hash joins.
2525
26- use std:: sync:: Arc ;
27-
2826use crate :: config:: ConfigOptions ;
2927use crate :: error:: Result ;
3028use crate :: physical_plan:: joins:: utils:: { ColumnIndex , JoinFilter } ;
@@ -34,6 +32,7 @@ use crate::physical_plan::joins::{
3432} ;
3533use crate :: physical_plan:: projection:: ProjectionExec ;
3634use crate :: physical_plan:: { ExecutionPlan , ExecutionPlanProperties } ;
35+ use std:: sync:: Arc ;
3736
3837use arrow_schema:: Schema ;
3938use datafusion_common:: tree_node:: { Transformed , TransformedResult , TreeNode } ;
@@ -1173,6 +1172,65 @@ mod tests_statistical {
11731172 ) ;
11741173 }
11751174
1175+ #[ rstest(
1176+ join_type, projection, small_on_right,
1177+ case:: inner( JoinType :: Inner , vec![ 1 ] , true ) ,
1178+ case:: left( JoinType :: Left , vec![ 1 ] , true ) ,
1179+ case:: right( JoinType :: Right , vec![ 1 ] , true ) ,
1180+ case:: full( JoinType :: Full , vec![ 1 ] , true ) ,
1181+ case:: left_anti( JoinType :: LeftAnti , vec![ 0 ] , false ) ,
1182+ case:: left_semi( JoinType :: LeftSemi , vec![ 0 ] , false ) ,
1183+ case:: right_anti( JoinType :: RightAnti , vec![ 0 ] , true ) ,
1184+ case:: right_semi( JoinType :: RightSemi , vec![ 0 ] , true ) ,
1185+ ) ]
1186+ #[ tokio:: test]
1187+ async fn test_hash_join_swap_on_joins_with_projections (
1188+ join_type : JoinType ,
1189+ projection : Vec < usize > ,
1190+ small_on_right : bool ,
1191+ ) -> Result < ( ) > {
1192+ let ( big, small) = create_big_and_small ( ) ;
1193+
1194+ let left = if small_on_right { & big } else { & small } ;
1195+ let right = if small_on_right { & small } else { & big } ;
1196+
1197+ let left_on = if small_on_right {
1198+ "big_col"
1199+ } else {
1200+ "small_col"
1201+ } ;
1202+ let right_on = if small_on_right {
1203+ "small_col"
1204+ } else {
1205+ "big_col"
1206+ } ;
1207+
1208+ let join = Arc :: new ( HashJoinExec :: try_new (
1209+ Arc :: clone ( left) ,
1210+ Arc :: clone ( right) ,
1211+ vec ! [ (
1212+ Arc :: new( Column :: new_with_schema( left_on, & left. schema( ) ) ?) ,
1213+ Arc :: new( Column :: new_with_schema( right_on, & right. schema( ) ) ?) ,
1214+ ) ] ,
1215+ None ,
1216+ & join_type,
1217+ Some ( projection) ,
1218+ PartitionMode :: Partitioned ,
1219+ false ,
1220+ ) ?) ;
1221+
1222+ let swapped = swap_hash_join ( & join. clone ( ) , PartitionMode :: Partitioned )
1223+ . expect ( "swap_hash_join must support joins with projections" ) ;
1224+ let swapped_join = swapped. as_any ( ) . downcast_ref :: < HashJoinExec > ( ) . expect (
1225+ "ProjectionExec won't be added above if HashJoinExec contains embedded projection" ,
1226+ ) ;
1227+
1228+ assert_eq ! ( swapped_join. projection, Some ( vec![ 0_usize ] ) ) ;
1229+ assert_eq ! ( swapped. schema( ) . fields. len( ) , 1 ) ;
1230+ assert_eq ! ( swapped. schema( ) . fields[ 0 ] . name( ) , "small_col" ) ;
1231+ Ok ( ( ) )
1232+ }
1233+
11761234 #[ rstest(
11771235 join_type,
11781236 case:: inner( JoinType :: Inner ) ,
0 commit comments