Skip to content

Commit

Permalink
apacheGH-36092: [C++] Simplify concurrency in as-of-join node
Browse files Browse the repository at this point in the history
  • Loading branch information
rtpsw committed Jun 15, 2023
1 parent 475b5b9 commit 582f572
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ struct MemoStore {
// the time of the current entry, defaulting to 0.
// when entries with a time less than T are removed, the current time is updated to the
// time of the next (by-time) and now-current entry or to T if no such entry exists.
std::atomic<OnType> current_time_;
OnType current_time_;
// current entry per key
std::unordered_map<ByType, Entry> entries_;
// future entries per key
Expand All @@ -364,21 +364,16 @@ struct MemoStore {
std::swap(index_, memo.index_);
#endif
std::swap(no_future_, memo.no_future_);
current_time_ = memo.current_time_.exchange(static_cast<OnType>(current_time_));
std::swap(current_time_, memo.current_time_);
entries_.swap(memo.entries_);
future_entries_.swap(memo.future_entries_);
times_.swap(memo.times_);
}

// Updates the current time to `ts` if it is less. A different thread may win the race
// to update the current time to more than `ts` but not to less. Returns whether the
// current time was changed from its value at the beginning of this invocation.
// Updates the current time to `ts` if it is less. Returns true if updated.
bool UpdateTime(OnType ts) {
OnType prev_time = current_time_;
bool update = prev_time < ts;
while (prev_time < ts && !current_time_.compare_exchange_weak(prev_time, ts)) {
// intentionally empty - standard CAS loop
}
bool update = current_time_ < ts;
if (update) current_time_ = ts;
return update;
}

Expand Down Expand Up @@ -529,7 +524,7 @@ class KeyHasher {
size_t index_;
std::vector<col_index_t> indices_;
std::vector<KeyColumnMetadata> metadata_;
std::atomic<const RecordBatch*> batch_;
const RecordBatch* batch_;
std::vector<HashType> hashes_;
LightContext ctx_;
std::vector<KeyColumnArray> column_arrays_;
Expand Down Expand Up @@ -821,8 +816,11 @@ class InputState {
++batches_processed_;
latest_ref_row_ = 0;
have_active_batch &= !queue_.TryPop();
if (have_active_batch)
if (have_active_batch) {
DCHECK_GT(queue_.UnsyncFront()->num_rows(), 0); // empty batches disallowed
key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
memo_.UpdateTime(GetTime(queue_.UnsyncFront().get(), 0)); // time changed
}
}
}
return have_active_batch;
Expand Down Expand Up @@ -898,8 +896,6 @@ class InputState {

Status Push(const std::shared_ptr<arrow::RecordBatch>& rb) {
if (rb->num_rows() > 0) {
key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache
memo_.UpdateTime(GetTime(rb.get(), 0)); // time changed - update in MemoStore
queue_.Push(rb); // only after above updates - push batch for processing
} else {
++batches_processed_; // don't enqueue empty batches, just record as processed
Expand Down

0 comments on commit 582f572

Please sign in to comment.