Skip to content

Commit

Permalink
Added batch memset to memset data and validity buffers in parquet rea…
Browse files Browse the repository at this point in the history
…der (#16281)

Under some situations in the Parquet reader (particularly the case with tables containing many columns or deeply nested column) we burn a decent amount of time doing cudaMemset() operations on output buffers. A good amount of this overhead seems to stem from the fact that we're simply launching many tiny kernels. This PR adds a batched memset kernel that takes a list of device spans as a single input and does all the work under a single kernel launch. This PR addresses issue #15773 

## Improvements
Using out performance cluster, improvements of 2.39% were shown on running the overall NDS queries
Additionally, benchmarks were added showing big improvements(around 20%) especially on fixed width data types which can be shown below

data_type | num_cols | cardinality | run_length | bytes_per_second_before_this_pr | bytes_per_second_after_this_pr | speedup
--- | --- | --- | --- | --- | --- | ---
INTEGRAL | 1000 | 0 | 1 | 36514934834 | 42756531566 | 1.170932709
INTEGRAL | 1000 | 1000 | 1 | 35364061247 | 39112512476 | 1.105996062
INTEGRAL | 1000 | 0 | 32 | 37349112510 | 39641370858 | 1.061373837
INTEGRAL | 1000 | 1000 | 32 | 39167079622 | 43740824957 | 1.116775245
FLOAT | 1000 | 0 | 1 | 51877322003 | 64083898838 | 1.235296973
FLOAT | 1000 | 1000 | 1 | 48983612272 | 58705522023 | 1.198472699
FLOAT | 1000 | 0 | 32 | 46544977658 | 53715018581 | 1.154045426
FLOAT | 1000 | 1000 | 32 | 54493432148 | 66617609904 | 1.22248879
DECIMAL | 1000 | 0 | 1 | 47616412888 | 57952310685 | 1.217065864
DECIMAL | 1000 | 1000 | 1 | 47166138095 | 54283772484 | 1.1509056
DECIMAL | 1000 | 0 | 32 | 45266163387 | 53770390830 | 1.18787162
DECIMAL | 1000 | 1000 | 32 | 52292176603 | 58847723569 | 1.125363819
TIMESTAMP | 1000 | 0 | 1 | 50245415328 | 60797982330 | 1.210020495
TIMESTAMP | 1000 | 1000 | 1 | 50300238706 | 60810368331 | 1.208947908
TIMESTAMP | 1000 | 0 | 32 | 55338354243 | 66786275739 | 1.206871376
TIMESTAMP | 1000 | 1000 | 32 | 55680028082 | 69029227374 | 1.23974843
DURATION | 1000 | 0 | 1 | 54680007758 | 66855201896 | 1.222662626
DURATION | 1000 | 1000 | 1 | 54305832171 | 66602436269 | 1.226432477
DURATION | 1000 | 0 | 32 | 60040760815 | 72663056969 | 1.210228784
DURATION | 1000 | 1000 | 32 | 60212221703 | 75646396131 | 1.256329595
STRING | 1000 | 0 | 1 | 29691707753 | 33388700976 | 1.12451265
STRING | 1000 | 1000 | 1 | 31411129876 | 35407241037 | 1.127219593
STRING | 1000 | 0 | 32 | 29680479388 | 33382478907 | 1.124728427
STRING | 1000 | 1000 | 32 | 35476213777 | 40478389269 | 1.141000827
LIST | 1000 | 0 | 1 | 6874253484 | 7370835717 | 1.072237987
LIST | 1000 | 1000 | 1 | 6763426009 | 7253762966 | 1.07249831
LIST | 1000 | 0 | 32 | 6981508808 | 7502741115 | 1.074658977
LIST | 1000 | 1000 | 32 | 6989374761 | 7506418252 | 1.073975643
STRUCT | 1000 | 0 | 1 | 2137525922 | 2189495762 | 1.024313081
STRUCT | 1000 | 1000 | 1 | 1057923939 | 1078475980 | 1.019426766
STRUCT | 1000 | 0 | 32 | 1637342446 | 1698913790 | 1.037604439
STRUCT | 1000 | 1000 | 32 | 1057587701 | 1082539399 | 1.02359303

Authors:
  - Rahul Prabhu (https://github.com/sdrp713)
  - Muhammad Haseeb (https://github.com/mhaseeb123)

Approvers:
  - https://github.com/nvdbaranec
  - Muhammad Haseeb (https://github.com/mhaseeb123)
  - Kyle Edwards (https://github.com/KyleFromNVIDIA)
  - Bradley Dice (https://github.com/bdice)

URL: #16281
  • Loading branch information
sdrp713 authored Aug 5, 2024
1 parent af57286 commit 837dfe5
Show file tree
Hide file tree
Showing 8 changed files with 353 additions and 14 deletions.
5 changes: 5 additions & 0 deletions cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ ConfigureNVBench(JSON_READER_NVBENCH io/json/nested_json.cpp io/json/json_reader
ConfigureNVBench(JSON_READER_OPTION_NVBENCH io/json/json_reader_option.cpp)
ConfigureNVBench(JSON_WRITER_NVBENCH io/json/json_writer.cpp)

# ##################################################################################################
# * multi buffer memset benchmark
# ----------------------------------------------------------------------
ConfigureNVBench(BATCHED_MEMSET_BENCH io/utilities/batched_memset_bench.cpp)

# ##################################################################################################
# * io benchmark ---------------------------------------------------------------------
ConfigureNVBench(MULTIBYTE_SPLIT_NVBENCH io/text/multibyte_split.cpp)
Expand Down
101 changes: 101 additions & 0 deletions cpp/benchmarks/io/utilities/batched_memset_bench.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <benchmarks/common/generate_input.hpp>
#include <benchmarks/fixture/benchmark_fixture.hpp>
#include <benchmarks/io/cuio_common.hpp>
#include <benchmarks/io/nvbench_helpers.hpp>

#include <cudf/io/parquet.hpp>
#include <cudf/utilities/default_stream.hpp>

#include <nvbench/nvbench.cuh>

// Size of the data in the benchmark dataframe; chosen to be low enough to allow benchmarks to
// run on most GPUs, but large enough to allow highest throughput
constexpr size_t data_size = 512 << 20;

void parquet_read_common(cudf::size_type num_rows_to_read,
cudf::size_type num_cols_to_read,
cuio_source_sink_pair& source_sink,
nvbench::state& state)
{
cudf::io::parquet_reader_options read_opts =
cudf::io::parquet_reader_options::builder(source_sink.make_source_info());

auto mem_stats_logger = cudf::memory_stats_logger();
state.set_cuda_stream(nvbench::make_cuda_stream_view(cudf::get_default_stream().value()));
state.exec(
nvbench::exec_tag::sync | nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
try_drop_l3_cache();

timer.start();
auto const result = cudf::io::read_parquet(read_opts);
timer.stop();

CUDF_EXPECTS(result.tbl->num_columns() == num_cols_to_read, "Unexpected number of columns");
CUDF_EXPECTS(result.tbl->num_rows() == num_rows_to_read, "Unexpected number of rows");
});

auto const time = state.get_summary("nv/cold/time/gpu/mean").get_float64("value");
state.add_element_count(static_cast<double>(data_size) / time, "bytes_per_second");
state.add_buffer_size(
mem_stats_logger.peak_memory_usage(), "peak_memory_usage", "peak_memory_usage");
state.add_buffer_size(source_sink.size(), "encoded_file_size", "encoded_file_size");
}

template <data_type DataType>
void bench_batched_memset(nvbench::state& state, nvbench::type_list<nvbench::enum_type<DataType>>)
{
auto const d_type = get_type_or_group(static_cast<int32_t>(DataType));
auto const num_cols = static_cast<cudf::size_type>(state.get_int64("num_cols"));
auto const cardinality = static_cast<cudf::size_type>(state.get_int64("cardinality"));
auto const run_length = static_cast<cudf::size_type>(state.get_int64("run_length"));
auto const source_type = retrieve_io_type_enum(state.get_string("io_type"));
auto const compression = cudf::io::compression_type::NONE;
cuio_source_sink_pair source_sink(source_type);
auto const tbl =
create_random_table(cycle_dtypes(d_type, num_cols),
table_size_bytes{data_size},
data_profile_builder().cardinality(cardinality).avg_run_length(run_length));
auto const view = tbl->view();

cudf::io::parquet_writer_options write_opts =
cudf::io::parquet_writer_options::builder(source_sink.make_sink_info(), view)
.compression(compression);
cudf::io::write_parquet(write_opts);
auto const num_rows = view.num_rows();

parquet_read_common(num_rows, num_cols, source_sink, state);
}

using d_type_list = nvbench::enum_type_list<data_type::INTEGRAL,
data_type::FLOAT,
data_type::DECIMAL,
data_type::TIMESTAMP,
data_type::DURATION,
data_type::STRING,
data_type::LIST,
data_type::STRUCT>;

NVBENCH_BENCH_TYPES(bench_batched_memset, NVBENCH_TYPE_AXES(d_type_list))
.set_name("batched_memset")
.set_type_axes_names({"data_type"})
.add_int64_axis("num_cols", {1000})
.add_string_axis("io_type", {"DEVICE_BUFFER"})
.set_min_samples(4)
.add_int64_axis("cardinality", {0, 1000})
.add_int64_axis("run_length", {1, 32});
82 changes: 82 additions & 0 deletions cpp/include/cudf/io/detail/batched_memset.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/detail/iterator.cuh>
#include <cudf/detail/utilities/vector_factories.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/resource_ref.hpp>

#include <cub/device/device_copy.cuh>
#include <cuda/functional>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/transform.h>

namespace CUDF_EXPORT cudf {
namespace io::detail {

/**
* @brief A helper function that takes in a vector of device spans and memsets them to the
* value provided using batches sent to the GPU.
*
* @param bufs Vector with device spans of data
* @param value Value to memset all device spans to
* @param _stream Stream used for device memory operations and kernel launches
*
* @return The data in device spans all set to value
*/
template <typename T>
void batched_memset(std::vector<cudf::device_span<T>> const& bufs,
T const value,
rmm::cuda_stream_view stream)
{
// define task and bytes parameters
auto const num_bufs = bufs.size();

// copy bufs into device memory and then get sizes
auto gpu_bufs =
cudf::detail::make_device_uvector_async(bufs, stream, rmm::mr::get_current_device_resource());

// get a vector with the sizes of all buffers
auto sizes = cudf::detail::make_counting_transform_iterator(
static_cast<std::size_t>(0),
cuda::proclaim_return_type<std::size_t>(
[gpu_bufs = gpu_bufs.data()] __device__(std::size_t i) { return gpu_bufs[i].size(); }));

// get an iterator with a constant value to memset
auto iter_in = thrust::make_constant_iterator(thrust::make_constant_iterator(value));

// get an iterator pointing to each device span
auto iter_out = thrust::make_transform_iterator(
thrust::counting_iterator<std::size_t>(0),
cuda::proclaim_return_type<T*>(
[gpu_bufs = gpu_bufs.data()] __device__(std::size_t i) { return gpu_bufs[i].data(); }));

size_t temp_storage_bytes = 0;

cub::DeviceCopy::Batched(nullptr, temp_storage_bytes, iter_in, iter_out, sizes, num_bufs, stream);

rmm::device_buffer d_temp_storage(
temp_storage_bytes, stream, rmm::mr::get_current_device_resource());

cub::DeviceCopy::Batched(
d_temp_storage.data(), temp_storage_bytes, iter_in, iter_out, sizes, num_bufs, stream);
}

} // namespace io::detail
} // namespace CUDF_EXPORT cudf
29 changes: 26 additions & 3 deletions cpp/src/io/parquet/reader_impl_preprocess.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/detail/batched_memset.hpp>

#include <rmm/exec_policy.hpp>

Expand Down Expand Up @@ -1494,6 +1495,11 @@ void reader::impl::allocate_columns(read_mode mode, size_t skip_rows, size_t num
// buffers if they are not part of a list hierarchy. mark down
// if we have any list columns that need further processing.
bool has_lists = false;
// Casting to std::byte since data buffer pointer is void *
std::vector<cudf::device_span<std::byte>> memset_bufs;
// Validity Buffer is a uint32_t pointer
std::vector<cudf::device_span<cudf::bitmask_type>> nullmask_bufs;

for (size_t idx = 0; idx < _input_columns.size(); idx++) {
auto const& input_col = _input_columns[idx];
size_t const max_depth = input_col.nesting_depth();
Expand All @@ -1514,13 +1520,19 @@ void reader::impl::allocate_columns(read_mode mode, size_t skip_rows, size_t num
// we're going to start null mask as all valid and then turn bits off if necessary
out_buf.create_with_mask(
out_buf.type.id() == type_id::LIST && l_idx < max_depth ? num_rows + 1 : num_rows,
cudf::mask_state::ALL_VALID,
cudf::mask_state::UNINITIALIZED,
false,
_stream,
_mr);
memset_bufs.push_back(cudf::device_span<std::byte>(static_cast<std::byte*>(out_buf.data()),
out_buf.data_size()));
nullmask_bufs.push_back(cudf::device_span<cudf::bitmask_type>(
out_buf.null_mask(),
cudf::util::round_up_safe(out_buf.null_mask_size(), sizeof(cudf::bitmask_type)) /
sizeof(cudf::bitmask_type)));
}
}
}

// compute output column sizes by examining the pages of the -input- columns
if (has_lists) {
auto h_cols_info =
Expand Down Expand Up @@ -1593,11 +1605,22 @@ void reader::impl::allocate_columns(read_mode mode, size_t skip_rows, size_t num

// allocate
// we're going to start null mask as all valid and then turn bits off if necessary
out_buf.create_with_mask(size, cudf::mask_state::ALL_VALID, _stream, _mr);
out_buf.create_with_mask(size, cudf::mask_state::UNINITIALIZED, false, _stream, _mr);
memset_bufs.push_back(cudf::device_span<std::byte>(
static_cast<std::byte*>(out_buf.data()), out_buf.data_size()));
nullmask_bufs.push_back(cudf::device_span<cudf::bitmask_type>(
out_buf.null_mask(),
cudf::util::round_up_safe(out_buf.null_mask_size(), sizeof(cudf::bitmask_type)) /
sizeof(cudf::bitmask_type)));
}
}
}
}

cudf::io::detail::batched_memset(memset_bufs, static_cast<std::byte>(0), _stream);
// Need to set null mask bufs to all high bits
cudf::io::detail::batched_memset(
nullmask_bufs, std::numeric_limits<cudf::bitmask_type>::max(), _stream);
}

std::vector<size_t> reader::impl::calculate_page_string_offsets()
Expand Down
29 changes: 22 additions & 7 deletions cpp/src/io/utilities/column_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

namespace cudf::io::detail {

void gather_column_buffer::allocate_strings_data(rmm::cuda_stream_view stream)
void gather_column_buffer::allocate_strings_data(bool memset_data, rmm::cuda_stream_view stream)
{
CUDF_EXPECTS(type.id() == type_id::STRING, "allocate_strings_data called for non-string column");
// The contents of _strings will never be directly returned to the user.
Expand All @@ -56,11 +56,12 @@ std::unique_ptr<column> gather_column_buffer::make_string_column_impl(rmm::cuda_
return make_strings_column(*_strings, stream, _mr);
}

void cudf::io::detail::inline_column_buffer::allocate_strings_data(rmm::cuda_stream_view stream)
void cudf::io::detail::inline_column_buffer::allocate_strings_data(bool memset_data,
rmm::cuda_stream_view stream)
{
CUDF_EXPECTS(type.id() == type_id::STRING, "allocate_strings_data called for non-string column");
// size + 1 for final offset. _string_data will be initialized later.
_data = create_data(data_type{type_id::INT32}, size + 1, stream, _mr);
_data = create_data(data_type{type_to_id<size_type>()}, size + 1, memset_data, stream, _mr);
}

void cudf::io::detail::inline_column_buffer::create_string_data(size_t num_bytes,
Expand Down Expand Up @@ -93,36 +94,50 @@ void copy_buffer_data(string_policy const& buff, string_policy& new_buff)
template <class string_policy>
void column_buffer_base<string_policy>::create_with_mask(size_type _size,
cudf::mask_state null_mask_state,
bool memset_data,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
size = _size;
_mr = mr;

switch (type.id()) {
case type_id::STRING: static_cast<string_policy*>(this)->allocate_strings_data(stream); break;
case type_id::STRING:
static_cast<string_policy*>(this)->allocate_strings_data(memset_data, stream);
break;

// list columns store a buffer of int32's as offsets to represent
// their individual rows
case type_id::LIST: _data = create_data(data_type{type_id::INT32}, size, stream, _mr); break;
case type_id::LIST:
_data = create_data(data_type{type_to_id<size_type>()}, size, memset_data, stream, _mr);
break;

// struct columns store no data themselves. just validity and children.
case type_id::STRUCT: break;

default: _data = create_data(type, size, stream, _mr); break;
default: _data = create_data(type, size, memset_data, stream, _mr); break;
}
if (is_nullable) {
_null_mask =
cudf::detail::create_null_mask(size, null_mask_state, rmm::cuda_stream_view(stream), _mr);
}
}

template <class string_policy>
void column_buffer_base<string_policy>::create(size_type _size,
bool memset_data,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
create_with_mask(_size, mask_state::ALL_NULL, memset_data, stream, mr);
}

template <class string_policy>
void column_buffer_base<string_policy>::create(size_type _size,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
create_with_mask(_size, mask_state::ALL_NULL, stream, mr);
create_with_mask(_size, mask_state::ALL_NULL, true, stream, mr);
}

template <class string_policy>
Expand Down
Loading

0 comments on commit 837dfe5

Please sign in to comment.