@@ -77,8 +77,26 @@ use arrow_buffer::BooleanBuffer;
7777use datafusion_expr:: Operator ;
7878use datafusion_physical_expr_common:: datum:: compare_op_for_nested;
7979use futures:: { ready, Stream , StreamExt , TryStreamExt } ;
80+ use log:: debug;
8081use parking_lot:: Mutex ;
8182
83+ pub const RANDOM_STATE : RandomState = RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ;
84+
85+ #[ derive( Default ) ]
86+ pub struct JoinContext {
87+ build_state : Mutex < Option < Arc < JoinLeftData > > > ,
88+ }
89+
90+ impl JoinContext {
91+ pub fn set_build_state ( & self , state : Arc < JoinLeftData > ) {
92+ self . build_state . lock ( ) . replace ( state) ;
93+ }
94+
95+ pub fn get_build_state ( & self ) -> Option < Arc < JoinLeftData > > {
96+ self . build_state . lock ( ) . clone ( )
97+ }
98+ }
99+
82100pub struct SharedJoinState {
83101 state_impl : Arc < dyn SharedJoinStateImpl > ,
84102}
@@ -128,7 +146,7 @@ pub trait SharedJoinStateImpl: Send + Sync + 'static {
128146type SharedBitmapBuilder = Mutex < BooleanBufferBuilder > ;
129147
130148/// HashTable and input data for the left (build side) of a join
131- struct JoinLeftData {
149+ pub struct JoinLeftData {
132150 /// The hash table with indices into `batch`
133151 hash_map : JoinHashMap ,
134152 /// The input rows for the build side
@@ -165,6 +183,10 @@ impl JoinLeftData {
165183 }
166184 }
167185
186+ pub fn contains_hash ( & self , hash : u64 ) -> bool {
187+ self . hash_map . contains_hash ( hash)
188+ }
189+
168190 /// return a reference to the hash map
169191 fn hash_map ( & self ) -> & JoinHashMap {
170192 & self . hash_map
@@ -768,6 +790,7 @@ impl ExecutionPlan for HashJoinExec {
768790
769791 let distributed_state =
770792 context. session_config ( ) . get_extension :: < SharedJoinState > ( ) ;
793+ let join_context = context. session_config ( ) . get_extension :: < JoinContext > ( ) ;
771794
772795 let join_metrics = BuildProbeJoinMetrics :: new ( partition, & self . metrics ) ;
773796 let left_fut = match self . mode {
@@ -855,6 +878,7 @@ impl ExecutionPlan for HashJoinExec {
855878 batch_size,
856879 hashes_buffer : vec ! [ ] ,
857880 right_side_ordered : self . right . output_ordering ( ) . is_some ( ) ,
881+ join_context,
858882 } ) )
859883 }
860884
@@ -1187,6 +1211,7 @@ struct HashJoinStream {
11871211 hashes_buffer : Vec < u64 > ,
11881212 /// Specifies whether the right side has an ordering to potentially preserve
11891213 right_side_ordered : bool ,
1214+ join_context : Option < Arc < JoinContext > > ,
11901215}
11911216
11921217impl RecordBatchStream for HashJoinStream {
@@ -1399,6 +1424,11 @@ impl HashJoinStream {
13991424 . get_shared( cx) ) ?;
14001425 build_timer. done ( ) ;
14011426
1427+ if let Some ( ctx) = self . join_context . as_ref ( ) {
1428+ debug ! ( "setting join left data in join context" ) ;
1429+ ctx. set_build_state ( Arc :: clone ( & left_data) ) ;
1430+ }
1431+
14021432 self . state = HashJoinStreamState :: FetchProbeBatch ;
14031433 self . build_side = BuildSide :: Ready ( BuildSideReadyState { left_data } ) ;
14041434
0 commit comments