Fix ORC reader when using device_read_async while the destination device buffers are not ready (#17074)

This fixes a bug in the ORC reader that occurs when `device_read_async` is called before the destination device buffers are ready to be written to. The cause is that `device_read_async` does not copy data on the user-provided stream but on a stream it generates itself. As a result, the copy operations can run before the destination device buffers have been allocated, causing data corruption.

This bug shows up only under certain conditions and is hard to reproduce. It occurs when copying small buffers (below `gds_threshold`) and is most likely to appear with `rmm_mode=async`.
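The ordering hazard described above can be illustrated with a small host-only sketch. It is a toy model, not cuDF or CUDA code: `toy_stream`, `toy_buffer`, `allocate_async`, `read_on_internal_stream`, and `load` are hypothetical names invented for this illustration, and the "stream" is just a queue of deferred tasks. It assumes the allocation is ordered on the caller's stream while the read ignores that stream, which is the situation the fix guards against by synchronizing first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for a device stream: work is queued and only
// executes when the stream is synchronized.
struct toy_stream {
  std::deque<std::function<void()>> pending;

  void enqueue(std::function<void()> task) { pending.push_back(std::move(task)); }

  void synchronize()
  {
    while (!pending.empty()) {
      auto task = std::move(pending.front());
      pending.pop_front();
      task();
    }
  }
};

// Hypothetical stand-in for a stream-ordered device buffer: the storage is
// usable only after the allocation task queued on its stream has run.
struct toy_buffer {
  std::vector<std::uint8_t> storage;
  bool ready{false};
};

// Queues the allocation on `stream`; nothing is allocated until it is drained.
void allocate_async(toy_buffer& buf, std::size_t size, toy_stream& stream)
{
  stream.enqueue([&buf, size] {
    buf.storage.assign(size, 0);
    buf.ready = true;
  });
}

// Models a read that does not honor the caller's stream: the copy happens
// right away. Returns false when the destination was not ready yet.
bool read_on_internal_stream(toy_buffer& dst, std::vector<std::uint8_t> const& src)
{
  if (!dst.ready || dst.storage.size() < src.size()) { return false; }
  std::copy(src.begin(), src.end(), dst.storage.begin());
  return true;
}

// Returns true when the data arrives intact in the destination buffer.
bool load(bool sync_first)
{
  toy_stream user_stream;
  toy_buffer dst;
  std::vector<std::uint8_t> const src{1, 2, 3, 4};

  allocate_async(dst, src.size(), user_stream);
  if (sync_first) { user_stream.synchronize(); }  // the guard the fix adds

  bool const copied = read_on_internal_stream(dst, src);
  user_stream.synchronize();
  return copied && dst.storage == src;
}
```

In this model, `load(true)` delivers the data intact, while `load(false)` attempts the copy before the queued allocation has run and loses the data. A real device would not report the failure this politely; as described above, the symptom is silent data corruption.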

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - David Wendt (https://github.com/davidwendt)

URL: #17074
ttnghia authored Oct 14, 2024
1 parent e41dea9 commit 768fbaa
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions cpp/src/io/orc/reader_impl_chunking.cu
@@ -500,13 +500,22 @@ void reader_impl::load_next_stripe_data(read_mode mode)
  auto const [read_begin, read_end] =
    merge_selected_ranges(_file_itm_data.stripe_data_read_ranges, load_stripe_range);

  bool stream_synchronized{false};

  for (auto read_idx = read_begin; read_idx < read_end; ++read_idx) {
    auto const& read_info = _file_itm_data.data_read_info[read_idx];
    auto const source_ptr = _metadata.per_file_metadata[read_info.source_idx].source;
    auto const dst_base = static_cast<uint8_t*>(
      lvl_stripe_data[read_info.level][read_info.stripe_idx - stripe_start].data());

    if (source_ptr->is_device_read_preferred(read_info.length)) {
      // `device_read_async` may not use _stream at all.
      // Instead, it may use some other stream(s) to sync the H->D memcpy.
      // As such, we need to make sure the device buffers in `lvl_stripe_data` are ready first.
      if (!stream_synchronized) {
        _stream.synchronize();
        stream_synchronized = true;
      }
      device_read_tasks.push_back(
        std::pair(source_ptr->device_read_async(
                    read_info.offset, read_info.length, dst_base + read_info.dst_pos, _stream),