Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add more comments for hash join and remove unwrap #12985

Merged
merged 1 commit into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::sync::Arc;
use common_base::base::tokio::sync::watch;
use common_base::base::tokio::sync::watch::Receiver;
use common_base::base::tokio::sync::watch::Sender;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::DataBlock;
use log::info;
Expand Down Expand Up @@ -72,7 +73,9 @@ impl BuildSpillCoordinator {
// If current waiting spilling builder is the last one, then spill all builders.
pub(crate) fn wait_spill(&self) -> Result<bool> {
if *self.dummy_ready_spill_receiver.borrow() {
self.ready_spill_watcher.send(false).unwrap();
self.ready_spill_watcher
.send(false)
.map_err(|_| ErrorCode::TokioError("ready_spill_watcher channel is closed"))?;
}
self.waiting_spill_count.fetch_add(1, Ordering::SeqCst);
let waiting_spill_count = self.waiting_spill_count.load(Ordering::Relaxed);
Expand Down Expand Up @@ -107,7 +110,9 @@ impl BuildSpillCoordinator {
if *rx.borrow() {
return Ok(());
}
rx.changed().await.unwrap();
rx.changed().await.map_err(|_| {
ErrorCode::TokioError("ready_spill_watcher channel's sender is dropped")
})?;
debug_assert!(*rx.borrow());
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ pub struct HashJoinBuildState {
pub(crate) chunk_size_limit: Arc<usize>,
/// Wait util all processors finish row space build, then go to next phase.
pub(crate) barrier: Barrier,
/// Wait all processors finish read spilled data, then go to new round build
pub(crate) restore_barrier: Barrier,
/// It will be increased by 1 when a new hash join build processor is created.
/// After the processor put its input data into `RowSpace`, it will be decreased by 1.
/// The processor will wait other processors to finish their work before starting to build hash table.
Expand All @@ -95,7 +93,12 @@ pub struct HashJoinBuildState {
pub(crate) build_worker_num: Arc<AtomicU32>,
/// Tasks for building hash table.
pub(crate) build_hash_table_tasks: Arc<RwLock<VecDeque<(usize, usize)>>>,

/// Spill related states
/// `send_val` is the message which will be send into `build_done_watcher` channel.
pub(crate) send_val: AtomicU8,
/// Wait all processors finish read spilled data, then go to new round build
pub(crate) restore_barrier: Barrier,
}

impl HashJoinBuildState {
Expand Down Expand Up @@ -251,7 +254,7 @@ impl HashJoinBuildState {
self.hash_join_state
.build_done_watcher
.send(self.send_val.load(Ordering::Relaxed))
.unwrap();
.map_err(|_| ErrorCode::TokioError("build_done_watcher channel is closed"))?;
return Ok(());
}

Expand Down Expand Up @@ -627,7 +630,7 @@ impl HashJoinBuildState {
self.hash_join_state
.build_done_watcher
.send(self.send_val.load(Ordering::Relaxed))
.unwrap();
.map_err(|_| ErrorCode::TokioError("build_done_watcher channel is closed"))?;
}
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,8 @@ pub struct HashJoinProbeState {
/// (Note: it doesn't mean the processor has finished its work, it just means it has finished probe hash table.)
/// When the counter is 0, processors will go to next phase's work
pub(crate) probe_workers: Mutex<usize>,
/// Record spill workers
pub(crate) spill_workers: Mutex<usize>,
/// Record final probe workers
pub(crate) final_probe_workers: Mutex<usize>,
/// Wait all `probe_workers` finish
pub(crate) barrier: Barrier,
/// Wait all processors to restore spilled data, then go to new probe
pub(crate) restore_barrier: Barrier,
/// The schema of probe side.
pub(crate) probe_schema: DataSchemaRef,
/// `probe_projections` only contains the columns from upstream required columns
Expand All @@ -84,8 +78,16 @@ pub struct HashJoinProbeState {
pub(crate) mark_scan_map_lock: Mutex<bool>,
/// Hash method
pub(crate) hash_method: HashMethodKind,
/// Spilled partitions set

/// Spill related states
xudong963 marked this conversation as resolved.
Show resolved Hide resolved
/// Record spill workers
pub(crate) spill_workers: Mutex<usize>,
/// Record final probe workers
pub(crate) final_probe_workers: Mutex<usize>,
/// Probe spilled partitions set
pub(crate) spill_partitions: Arc<RwLock<HashSet<u8>>>,
/// Wait all processors to restore spilled data, then go to new probe
pub(crate) restore_barrier: Barrier,
}

impl HashJoinProbeState {
Expand Down Expand Up @@ -276,7 +278,7 @@ impl HashJoinProbeState {
Ok(res)
}

pub fn finish_final_probe(&self) {
pub fn finish_final_probe(&self) -> Result<()> {
let mut count = self.final_probe_workers.lock();
*count -= 1;
if *count == 0 {
Expand All @@ -297,8 +299,9 @@ impl HashJoinProbeState {
self.hash_join_state
.continue_build_watcher
.send(true)
.unwrap();
.map_err(|_| ErrorCode::TokioError("continue_build_watcher channel is closed"))?;
}
Ok(())
}

pub fn probe_done(&self) -> Result<()> {
Expand All @@ -311,7 +314,7 @@ impl HashJoinProbeState {
Ok(())
}

pub fn finish_spill(&self) {
pub fn finish_spill(&self) -> Result<()> {
let mut count = self.final_probe_workers.lock();
*count -= 1;
let mut count = self.spill_workers.lock();
Expand All @@ -333,8 +336,9 @@ impl HashJoinProbeState {
self.hash_join_state
.continue_build_watcher
.send(true)
.unwrap();
.map_err(|_| ErrorCode::TokioError("continue_build_watcher channel is closed"))?;
}
Ok(())
}

pub fn generate_final_scan_task(&self) -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::sync::Arc;
use common_base::base::tokio::sync::watch;
use common_base::base::tokio::sync::watch::Receiver;
use common_base::base::tokio::sync::watch::Sender;
use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::types::DataType;
use common_expression::ColumnVec;
Expand Down Expand Up @@ -82,6 +83,10 @@ pub struct HashJoinState {
/// And the build phase is finished. Probe phase will start.
pub(crate) hash_table_builders: Mutex<usize>,
/// After `hash_table_builders` is 0, send message to notify all probe processors.
/// There are three types' messages:
/// 1. **0**: it's the initial message used by creating the watch channel
/// 2. **1**: when build side finish (the first round), the last build processor will send 1 to channel, and wake up all probe processors.
/// 3. **2**: if spill is enabled, after the first round, probe needs to wait build again, the last build processor will send 2 to channel.
pub(crate) build_done_watcher: Sender<u8>,
/// A dummy receiver to make build done watcher channel open
pub(crate) _build_done_dummy_receiver: Receiver<u8>,
Expand All @@ -106,9 +111,12 @@ pub struct HashJoinState {
pub(crate) outer_scan_map: Arc<SyncUnsafeCell<Vec<Vec<bool>>>>,
/// LeftMarkScan map, initialized at `HashJoinBuildState`, used in `HashJoinProbeState`
pub(crate) mark_scan_map: Arc<SyncUnsafeCell<Vec<Vec<u8>>>>,

/// Spill related states
/// Spill partition set
pub(crate) spill_partition: Arc<RwLock<HashSet<u8>>>,
pub(crate) build_spilled_partitions: Arc<RwLock<HashSet<u8>>>,
/// Send message to notify all build processors to next round.
/// Initial message is false, send true to wake up all build processors.
pub(crate) continue_build_watcher: Sender<bool>,
/// A dummy receiver to make continue build watcher channel open
pub(crate) _continue_build_dummy_receiver: Receiver<bool>,
Expand Down Expand Up @@ -154,7 +162,7 @@ impl HashJoinState {
is_build_projected: Arc::new(AtomicBool::new(true)),
outer_scan_map: Arc::new(SyncUnsafeCell::new(Vec::new())),
mark_scan_map: Arc::new(SyncUnsafeCell::new(Vec::new())),
spill_partition: Default::default(),
build_spilled_partitions: Default::default(),
continue_build_watcher,
_continue_build_dummy_receiver,
partition_id: Arc::new(RwLock::new(-2)),
Expand All @@ -165,25 +173,31 @@ impl HashJoinState {
self.interrupt.store(true, Ordering::Release);
}

/// Used by hash join probe processors, wait for build phase finished.
/// Used by hash join probe processors, wait for the first round build phase finished.
#[async_backtrace::framed]
pub async fn wait_build_hash_table_finish(&self) -> Result<()> {
pub async fn wait_first_round_build_done(&self) -> Result<()> {
let mut rx = self.build_done_watcher.subscribe();
if *rx.borrow() == 1_u8 {
return Ok(());
}
rx.changed().await.unwrap();
rx.changed()
.await
.map_err(|_| ErrorCode::TokioError("build_done_watcher's sender is dropped"))?;
debug_assert!(*rx.borrow() == 1_u8);
Ok(())
}

/// Used by hash join probe processors, wait for build phase finished with spilled data
/// It's only be used when spilling is enabled.
#[async_backtrace::framed]
pub async fn wait_build_finish(&self) -> Result<()> {
let mut rx = self.build_done_watcher.subscribe();
if *rx.borrow() == 2_u8 {
return Ok(());
}
rx.changed().await.unwrap();
rx.changed()
.await
.map_err(|_| ErrorCode::TokioError("build_done_watcher's sender is dropped"))?;
debug_assert!(*rx.borrow() == 2_u8);
Ok(())
}
Expand All @@ -209,7 +223,7 @@ impl HashJoinState {
}

pub fn set_spilled_partition(&self, partitions: &HashSet<u8>) {
let mut spill_partition = self.spill_partition.write();
let mut spill_partition = self.build_spilled_partitions.write();
*spill_partition = partitions.clone();
}

Expand All @@ -219,7 +233,9 @@ impl HashJoinState {
if *rx.borrow() {
return Ok(());
}
rx.changed().await.unwrap();
rx.changed()
.await
.map_err(|_| ErrorCode::TokioError("continue_build_watcher's sender is dropped"))?;
debug_assert!(*rx.borrow());
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::any::Any;
use std::sync::atomic::Ordering;
use std::sync::Arc;

use common_exception::ErrorCode;
use common_exception::Result;
use common_expression::DataBlock;
use log::info;
Expand Down Expand Up @@ -49,16 +50,16 @@ enum HashJoinBuildStep {

pub struct TransformHashJoinBuild {
input_port: Arc<InputPort>,

input_data: Option<DataBlock>,
step: HashJoinBuildStep,
build_state: Arc<HashJoinBuildState>,
spill_state: Option<Box<BuildSpillState>>,
spill_data: Option<DataBlock>,
finalize_finished: bool,
processor_id: usize,

// The flag indicates whether data is from spilled data.
from_spill: bool,
processor_id: usize,
spill_state: Option<Box<BuildSpillState>>,
spill_data: Option<DataBlock>,
}

impl TransformHashJoinBuild {
Expand Down Expand Up @@ -104,7 +105,7 @@ impl TransformHashJoinBuild {
.spill_coordinator
.ready_spill_watcher
.send(true)
.unwrap();
.map_err(|_| ErrorCode::TokioError("ready_spill_watcher channel is closed"))?;
self.step = HashJoinBuildStep::FirstSpill;
}
Ok(())
Expand All @@ -119,11 +120,13 @@ impl TransformHashJoinBuild {
let mut count = self.build_state.row_space_builders.lock();
if *count == 0 {
self.build_state.send_val.store(2, Ordering::Relaxed);
// Before build processors into `WaitProbe` state, set the channel message to false.
// Then after all probe processors are ready, the last one will send true to channel and wake up all build processors.
self.build_state
.hash_join_state
.continue_build_watcher
.send(false)
.unwrap();
.map_err(|_| ErrorCode::TokioError("continue_build_watcher channel is closed"))?;
let worker_num = self.build_state.build_worker_num.load(Ordering::Relaxed) as usize;
*count = worker_num;
let mut count = self.build_state.hash_join_state.hash_table_builders.lock();
Expand Down Expand Up @@ -165,7 +168,11 @@ impl Processor for TransformHashJoinBuild {
let mut spill_task = spill_coordinator.spill_tasks.lock();
spill_state.split_spill_tasks(spill_coordinator.active_processor_num(), &mut spill_task)?;
spill_coordinator.waiting_spill_count.store(0, Ordering::Relaxed);
spill_coordinator.ready_spill_watcher.send(true).unwrap();
spill_coordinator.ready_spill_watcher.send(true).map_err(|_| {
ErrorCode::TokioError(
"ready_spill_watcher channel is closed",
)
})?;
}
}
self.build_state.row_space_build_done()?;
Expand Down Expand Up @@ -329,7 +336,9 @@ impl Processor for TransformHashJoinBuild {
.hash_join_state
.build_done_watcher
.send(2)
.unwrap();
.map_err(|_| {
ErrorCode::TokioError("build_done_watcher channel is closed")
})?;
self.step = HashJoinBuildStep::Finished;
return Ok(());
}
Expand Down
Loading