Batch memcpy the last offsets for output buffers of str and list cols in PQ reader #16905

Merged
Changes from 6 commits
Commits (29)
74ee6ae
Add capability to batch memcpy the last offsets to str and list out_bufs
mhaseeb123 Sep 25, 2024
cab885d
Move `WriteFinalOffsetsBatched` out of the for loop
mhaseeb123 Sep 25, 2024
b15e3d3
Generalize the API and ORC changes by @vuule
mhaseeb123 Sep 25, 2024
50dcd71
Use make_zeroed_device_uvector_async instead
mhaseeb123 Sep 25, 2024
bd44ca0
Merge branch 'branch-24.12' into fea-batch-memcpy-list-str-output-buf…
mhaseeb123 Sep 26, 2024
800b271
Add gtest for batched_memcpy
mhaseeb123 Sep 26, 2024
31a755b
Update cpp/include/cudf/io/detail/batched_memcpy.hpp
mhaseeb123 Sep 26, 2024
b29329b
Update cpp/include/cudf/io/detail/batched_memcpy.hpp
mhaseeb123 Sep 26, 2024
4efb989
Comments update
mhaseeb123 Sep 26, 2024
cc2829f
Address reviewer comments
mhaseeb123 Sep 27, 2024
78d68a8
Style fix
mhaseeb123 Sep 27, 2024
d42da45
Remove the unnecessary iterator
mhaseeb123 Sep 27, 2024
8d5640d
Move batched_memxxx to include/detail/utilities
mhaseeb123 Sep 27, 2024
9e063af
Minor changes from reviews
mhaseeb123 Sep 30, 2024
cf98118
Merge branch 'branch-24.12' into fea-batch-memcpy-list-str-output-buf…
mhaseeb123 Sep 30, 2024
2372fbb
Minor updates
mhaseeb123 Sep 30, 2024
6100c94
Merge branch 'fea-batch-memcpy-list-str-output-buff-offsets' of https…
mhaseeb123 Sep 30, 2024
4ea0930
Minor comment update
mhaseeb123 Oct 1, 2024
3eea6e2
Minor comment update
mhaseeb123 Oct 1, 2024
6d078c2
Style fix and add to CI.
mhaseeb123 Oct 1, 2024
1cc4e1f
Revert erroneous commit
mhaseeb123 Oct 1, 2024
042cfc0
Update cpp/include/cudf/detail/utilities/batched_memcpy.hpp
mhaseeb123 Oct 2, 2024
eee6f6d
Apply suggestions from review
mhaseeb123 Oct 2, 2024
828e0ac
Minor updates from review
mhaseeb123 Oct 2, 2024
ecc4252
Minor
mhaseeb123 Oct 2, 2024
4bd83db
Merge branch 'branch-24.12' into fea-batch-memcpy-list-str-output-buf…
mhaseeb123 Oct 2, 2024
871854b
Update cpp/src/io/parquet/page_data.cu
mhaseeb123 Oct 3, 2024
3e30777
Comments update.
mhaseeb123 Oct 3, 2024
16540a1
Merge branch 'branch-24.12' into fea-batch-memcpy-list-str-output-buf…
mhaseeb123 Oct 3, 2024
69 changes: 69 additions & 0 deletions cpp/include/cudf/io/detail/batched_memcpy.hpp
@@ -0,0 +1,69 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/detail/iterator.cuh>
#include <cudf/utilities/memory_resource.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>

#include <cub/device/device_memcpy.cuh>
#include <cuda/functional>
#include <thrust/iterator/constant_iterator.h>

namespace CUDF_EXPORT cudf {
namespace io::detail {

/**
* @brief A helper function that copies a batch of device buffers from source to destination
* addresses in a single batched operation.
*
* @tparam SrcIterator The type of the source address iterator
* @tparam DstIterator The type of the destination address iterator
* @tparam SizeIterator The type of the buffer size iterator
*
* @param[in] src_iter Iterator to source addresses
* @param[in] dst_iter Iterator to destination addresses
* @param[in] size_iter Iterator to the vector sizes (in bytes)
* @param[in] num_elems Number of buffers to copy
* @param[in] stream CUDA stream to use
*/
template <typename SrcIterator, typename DstIterator, typename SizeIterator>
void batched_memcpy(SrcIterator src_iter,
DstIterator dst_iter,
SizeIterator size_iter,
size_t num_elems,
rmm::cuda_stream_view stream)
{
// Get temp storage needed for cub::DeviceMemcpy::Batched
size_t temp_storage_bytes = 0;
cub::DeviceMemcpy::Batched(
nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, num_elems, stream.value());

// Allocate temporary storage
rmm::device_buffer d_temp_storage{temp_storage_bytes, stream.value()};

// Run cub::DeviceMemcpy::Batched
cub::DeviceMemcpy::Batched(d_temp_storage.data(),
temp_storage_bytes,
src_iter,
dst_iter,
size_iter,
num_elems,
stream.value());
}

} // namespace io::detail
} // namespace CUDF_EXPORT cudf
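As context for how this utility is meant to be called, here is a minimal, self-contained sketch that drives it with plain device arrays of addresses and byte counts. The helper name and the data are illustrative, not part of this PR, and the include path matches where the header sits in this revision of the diff.

```cpp
// Illustrative sketch only: copy two device buffers with one batched launch.
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/batched_memcpy.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <cstddef>
#include <cstdint>
#include <vector>

void copy_two_buffers_batched()
{
  auto stream = cudf::get_default_stream();
  auto mr     = rmm::mr::get_current_device_resource();

  // Two source buffers on the device (example data).
  std::vector<int32_t> h_src0{1, 2, 3};
  std::vector<int32_t> h_src1{4, 5};
  auto d_src0 = cudf::detail::make_device_uvector_async(h_src0, stream, mr);
  auto d_src1 = cudf::detail::make_device_uvector_async(h_src1, stream, mr);

  // Destination buffers of matching sizes.
  rmm::device_uvector<int32_t> d_dst0(h_src0.size(), stream);
  rmm::device_uvector<int32_t> d_dst1(h_src1.size(), stream);

  // Device arrays of source addresses, destination addresses, and sizes in bytes.
  std::vector<void*> h_srcs{d_src0.data(), d_src1.data()};
  std::vector<void*> h_dsts{d_dst0.data(), d_dst1.data()};
  std::vector<std::size_t> h_sizes{h_src0.size() * sizeof(int32_t),
                                   h_src1.size() * sizeof(int32_t)};
  auto srcs  = cudf::detail::make_device_uvector_async(h_srcs, stream, mr);
  auto dsts  = cudf::detail::make_device_uvector_async(h_dsts, stream, mr);
  auto sizes = cudf::detail::make_device_uvector_async(h_sizes, stream, mr);

  // Raw device pointers act as the src/dst/size iterators for cub::DeviceMemcpy::Batched.
  cudf::io::detail::batched_memcpy(srcs.data(), dsts.data(), sizes.data(), h_sizes.size(), stream);
  stream.synchronize();  // keep the host vectors alive until the async copies complete
}
```

The ORC and Parquet call sites below follow the same pattern: the ORC path gathers the addresses with a kernel and passes raw device pointers, while the Parquet path builds counting/transform iterators on the fly.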
64 changes: 44 additions & 20 deletions cpp/src/io/orc/stripe_enc.cu
@@ -23,6 +23,7 @@
#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/detail/utilities/logger.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/batched_memcpy.hpp>
#include <cudf/io/orc_types.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/utilities/bit.hpp>
@@ -1087,37 +1088,42 @@
/**
* @brief Merge chunked column data into a single contiguous stream
*
* @param[in,out] strm_desc StripeStream device array [stripe][stream]
* @param[in,out] streams List of encoder chunk streams [column][rowgroup]
* @param[in] strm_desc StripeStream device array [stripe][stream]
* @param[in] streams List of encoder chunk streams [column][rowgroup]
* @param[out] srcs List of source encoder chunk stream data addresses
* @param[out] dsts List of destination StripeStream data addresses
* @param[out] sizes List of stream sizes in bytes
*/
// blockDim {compact_streams_block_size,1,1}
CUDF_KERNEL void __launch_bounds__(compact_streams_block_size)
gpuCompactOrcDataStreams(device_2dspan<StripeStream> strm_desc,
device_2dspan<encoder_chunk_streams> streams)
gpuInitBatchedMemcpy(device_2dspan<StripeStream const> strm_desc,
device_2dspan<encoder_chunk_streams> streams,
device_span<uint8_t*> srcs,
device_span<uint8_t*> dsts,
device_span<size_t> sizes)
{
__shared__ __align__(16) StripeStream ss;

auto const stripe_id = blockIdx.x;
auto const stripe_id = blockIdx.x * compact_streams_block_size + threadIdx.x;
auto const stream_id = blockIdx.y;
auto const t = threadIdx.x;
if (stripe_id >= strm_desc.size().first) { return; }

if (t == 0) { ss = strm_desc[stripe_id][stream_id]; }
__syncthreads();
auto const out_id = stream_id * strm_desc.size().first + stripe_id;
StripeStream ss = strm_desc[stripe_id][stream_id];

if (ss.data_ptr == nullptr) { return; }

auto const cid = ss.stream_type;
auto dst_ptr = ss.data_ptr;
for (auto group = ss.first_chunk_id; group < ss.first_chunk_id + ss.num_chunks; ++group) {
auto const out_id = stream_id * streams.size().second + group;
srcs[out_id] = streams[ss.column_id][group].data_ptrs[cid];
dsts[out_id] = dst_ptr;

// Also update the stream here; the data will be copied in a separate kernel
streams[ss.column_id][group].data_ptrs[cid] = dst_ptr;

auto const len = streams[ss.column_id][group].lengths[cid];
if (len > 0) {
auto const src_ptr = streams[ss.column_id][group].data_ptrs[cid];
for (uint32_t i = t; i < len; i += blockDim.x) {
dst_ptr[i] = src_ptr[i];
}
__syncthreads();
}
if (t == 0) { streams[ss.column_id][group].data_ptrs[cid] = dst_ptr; }
// Multiplying by sizeof(uint8_t) = 1 is redundant here.
sizes[out_id] = len;
dst_ptr += len;
}
}
@@ -1326,8 +1332,26 @@ void CompactOrcDataStreams(device_2dspan<StripeStream> strm_desc,
rmm::cuda_stream_view stream)
{
dim3 dim_block(compact_streams_block_size, 1);
dim3 dim_grid(strm_desc.size().first, strm_desc.size().second);
gpuCompactOrcDataStreams<<<dim_grid, dim_block, 0, stream.value()>>>(strm_desc, enc_streams);

auto const num_rowgroups = enc_streams.size().second;
auto const num_streams = strm_desc.size().second;
auto const num_stripes = strm_desc.size().first;
auto const num_chunks = num_rowgroups * num_streams;
auto srcs = cudf::detail::make_zeroed_device_uvector_async<uint8_t*>(
num_chunks, stream, rmm::mr::get_current_device_resource());
auto dsts = cudf::detail::make_zeroed_device_uvector_async<uint8_t*>(
num_chunks, stream, rmm::mr::get_current_device_resource());
auto lengths = cudf::detail::make_zeroed_device_uvector_async<size_t>(
num_chunks, stream, rmm::mr::get_current_device_resource());

dim3 dim_grid_alt(cudf::util::div_rounding_up_unsafe(num_stripes, compact_streams_block_size),
strm_desc.size().second);
gpuInitBatchedMemcpy<<<dim_grid_alt, dim_block, 0, stream.value()>>>(
strm_desc, enc_streams, srcs, dsts, lengths);

// Copy streams in a batched manner.
cudf::io::detail::batched_memcpy(
srcs.data(), dsts.data(), lengths.data(), lengths.size(), stream);
}

std::optional<writer_compression_statistics> CompressOrcDataStreams(
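For reference, the flattening that maps each (stream, rowgroup) chunk onto a slot of the `srcs`/`dsts`/`sizes` arrays can be written as a tiny helper. This is illustrative only; it assumes `group` indexes rowgroups and that `num_rowgroups == enc_streams.size().second`, which matches the host-side `num_chunks = num_rowgroups * num_streams` allocation.

```cpp
// Illustrative helper, not part of the PR: the slot a (stream, rowgroup) chunk
// occupies in the srcs/dsts/sizes arrays consumed by the batched memcpy.
// One slot per (stream, rowgroup) pair, so num_streams * num_rowgroups slots in total.
#include <cstddef>

constexpr std::size_t chunk_slot(std::size_t stream_id,
                                 std::size_t group,
                                 std::size_t num_rowgroups)
{
  return stream_id * num_rowgroups + group;
}
```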
31 changes: 31 additions & 0 deletions cpp/src/io/parquet/page_data.cu
@@ -17,6 +17,8 @@
#include "page_data.cuh"
#include "page_decode.cuh"

#include <cudf/io/detail/batched_memcpy.hpp>

#include <rmm/exec_policy.hpp>

#include <thrust/reduce.h>
@@ -466,4 +468,33 @@ void __host__ DecodeSplitPageData(cudf::detail::hostdevice_span<PageInfo> pages,
}
}

/**
* @copydoc cudf::io::parquet::detail::WriteFinalOffsetsBatched
*/
void __host__ WriteFinalOffsetsBatched(std::vector<size_type> const& offsets,
std::vector<size_type*> const& buff_addrs,
rmm::device_async_resource_ref mr,
rmm::cuda_stream_view stream)
{
// Copy offsets to device and create an iterator
auto d_src_data = cudf::detail::make_device_uvector_async(offsets, stream, mr);
auto src_iter = cudf::detail::make_counting_transform_iterator(
static_cast<std::size_t>(0),
cuda::proclaim_return_type<size_type*>(
[src = d_src_data.data()] __device__(std::size_t i) { return src + i; }));

// Copy buffer addresses to device and create an iterator
auto d_dst_addrs = cudf::detail::make_device_uvector_async(buff_addrs, stream, mr);
auto dst_iter = cudf::detail::make_counting_transform_iterator(
static_cast<std::size_t>(0),
cuda::proclaim_return_type<size_type*>(
[dst = d_dst_addrs.data()] __device__(std::size_t i) { return dst[i]; }));

// size_iter is a constant iterator that returns sizeof(size_type) for every copy.
auto size_iter = thrust::make_constant_iterator(sizeof(size_type));

// Copy offsets to buffers in batched manner.
cudf::io::detail::batched_memcpy(src_iter, dst_iter, size_iter, offsets.size(), stream);
}

} // namespace cudf::io::parquet::detail
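For comparison, here is a sketch of the per-column pattern this batched launch replaces. It mirrors the `cudaMemcpyAsync` calls removed from `reader_impl.cpp` below; the function name is hypothetical.

```cpp
// Sketch only: the pre-PR pattern of one tiny copy per string/list column.
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <cuda_runtime.h>

#include <vector>

void write_final_offsets_one_by_one(std::vector<cudf::size_type> const& offsets,
                                    std::vector<cudf::size_type*> const& buff_addrs,
                                    rmm::cuda_stream_view stream)
{
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    CUDF_CUDA_TRY(cudaMemcpyAsync(buff_addrs[i],  // device address of the column's last offset
                                  &offsets[i],    // host-side final offset value
                                  sizeof(cudf::size_type),
                                  cudaMemcpyDefault,
                                  stream.value()));
  }
}
```

Each of those is a separate tiny host-to-device transfer; the batched version above folds them all into a single `cub::DeviceMemcpy::Batched` launch.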
15 changes: 15 additions & 0 deletions cpp/src/io/parquet/parquet_gpu.hpp
@@ -796,6 +796,21 @@ void DecodeSplitPageData(cudf::detail::hostdevice_span<PageInfo> pages,
kernel_error::pointer error_code,
rmm::cuda_stream_view stream);

/**
* @brief Writes the final offsets to the corresponding list and string buffer end addresses in a
* batched manner.
*
* @param[in] offsets A vector of final offsets
* @param[in] buff_addrs A vector of corresponding buffer end addresses
* @param[in] mr Device memory resource to allocate temporary memory
* @param[in] stream CUDA stream to use
*/
void WriteFinalOffsetsBatched(std::vector<size_type> const& offsets,
std::vector<size_type*> const& buff_addrs,
rmm::device_async_resource_ref mr,
rmm::cuda_stream_view stream);

/**
* @brief Launches kernel for reading the string column data stored in the pages
*
20 changes: 10 additions & 10 deletions cpp/src/io/parquet/reader_impl.cpp
@@ -378,6 +378,10 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num
// that it is difficult/impossible for a given page to know that it is writing the very
// last value that should then be followed by a terminator (because rows can span
// page boundaries).
std::vector<size_type*> out_buffers;
std::vector<size_type> final_offsets;
out_buffers.reserve(_input_columns.size());
final_offsets.reserve(_input_columns.size());
for (size_t idx = 0; idx < _input_columns.size(); idx++) {
input_column_info const& input_col = _input_columns[idx];

@@ -393,25 +397,21 @@

// the final offset for a list at level N is the size of its child
size_type const offset = child.type.id() == type_id::LIST ? child.size - 1 : child.size;
CUDF_CUDA_TRY(cudaMemcpyAsync(static_cast<size_type*>(out_buf.data()) + (out_buf.size - 1),
&offset,
sizeof(size_type),
cudaMemcpyDefault,
_stream.value()));
out_buffers.emplace_back(static_cast<size_type*>(out_buf.data()) + (out_buf.size - 1));
final_offsets.emplace_back(offset);
out_buf.user_data |= PARQUET_COLUMN_BUFFER_FLAG_LIST_TERMINATED;
} else if (out_buf.type.id() == type_id::STRING) {
// need to cap off the string offsets column
auto const sz = static_cast<size_type>(col_string_sizes[idx]);
if (sz <= strings::detail::get_offset64_threshold()) {
CUDF_CUDA_TRY(cudaMemcpyAsync(static_cast<size_type*>(out_buf.data()) + out_buf.size,
&sz,
sizeof(size_type),
cudaMemcpyDefault,
_stream.value()));
out_buffers.emplace_back(static_cast<size_type*>(out_buf.data()) + out_buf.size);
final_offsets.emplace_back(sz);
}
}
}
}
// Write the final offsets for all list and string buffers in a batched manner
cudf::io::parquet::detail::WriteFinalOffsetsBatched(final_offsets, out_buffers, _mr, _stream);

// update null counts in the final column buffers
for (size_t idx = 0; idx < subpass.pages.size(); idx++) {
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
@@ -385,6 +385,8 @@ ConfigureTest(
# * utilities tests -------------------------------------------------------------------------------
ConfigureTest(
UTILITIES_TEST
utilities_tests/batched_memcpy_tests.cu
utilities_tests/batched_memset_tests.cu
utilities_tests/column_debug_tests.cpp
utilities_tests/column_utilities_tests.cpp
utilities_tests/column_wrapper_tests.cpp
@@ -395,7 +397,6 @@ ConfigureTest(
utilities_tests/pinned_memory_tests.cpp
utilities_tests/type_check_tests.cpp
utilities_tests/type_list_tests.cpp
utilities_tests/batched_memset_tests.cu
)

# ##################################################################################################
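The new `utilities_tests/batched_memcpy_tests.cu` itself is not shown in this view of the diff. Below is a minimal sketch of the kind of round-trip check it could contain; the fixture, test name, and data are assumptions rather than the actual test code.

```cpp
// Hypothetical sketch, not the PR's actual test file: a round-trip check for
// cudf::io::detail::batched_memcpy. Both buffers are the same width, so a
// constant iterator can serve as the size iterator.
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/batched_memcpy.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <cudf_test/base_fixture.hpp>
#include <cudf_test/cudf_gtest.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <thrust/iterator/constant_iterator.h>

#include <cstdint>
#include <vector>

struct BatchedMemcpyTest : public cudf::test::BaseFixture {};

TEST_F(BatchedMemcpyTest, RoundTripsEquallySizedBuffers)
{
  auto stream = cudf::get_default_stream();
  auto mr     = rmm::mr::get_current_device_resource();

  // Two equally sized source buffers on the device.
  std::vector<int32_t> h_src0{1, 2, 3, 4};
  std::vector<int32_t> h_src1{5, 6, 7, 8};
  auto d_src0 = cudf::detail::make_device_uvector_async(h_src0, stream, mr);
  auto d_src1 = cudf::detail::make_device_uvector_async(h_src1, stream, mr);

  // Matching destination buffers, plus device arrays of the source/destination addresses.
  rmm::device_uvector<int32_t> d_dst0(h_src0.size(), stream);
  rmm::device_uvector<int32_t> d_dst1(h_src1.size(), stream);
  std::vector<int32_t*> h_srcs{d_src0.data(), d_src1.data()};
  std::vector<int32_t*> h_dsts{d_dst0.data(), d_dst1.data()};
  auto srcs = cudf::detail::make_device_uvector_async(h_srcs, stream, mr);
  auto dsts = cudf::detail::make_device_uvector_async(h_dsts, stream, mr);

  // Every copy is the same number of bytes, so a constant iterator suffices.
  auto size_iter = thrust::make_constant_iterator(h_src0.size() * sizeof(int32_t));
  cudf::io::detail::batched_memcpy(srcs.data(), dsts.data(), size_iter, 2, stream);

  // Read back and verify both destinations match their sources.
  EXPECT_EQ(cudf::detail::make_std_vector_sync(d_dst0, stream), h_src0);
  EXPECT_EQ(cudf::detail::make_std_vector_sync(d_dst1, stream), h_src1);
}
```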