Skip to content

Commit

Permalink
Removing thrust::optional on device_uvector until it can be sorted out
Browse files Browse the repository at this point in the history
  • Loading branch information
hyperbolic2346 committed Mar 17, 2021
1 parent 0b766c5 commit 8baa2b7
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions cpp/src/lists/explode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ std::unique_ptr<table> build_table(
column_view const& sliced_child,
cudf::device_span<size_type const> gather_map,
thrust::optional<cudf::device_span<size_type const>> explode_col_gather_map,
thrust::optional<rmm::device_uvector<size_type>> position_array,
rmm::device_uvector<size_type> position_array,
bool include_position,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
Expand Down Expand Up @@ -72,11 +73,11 @@ std::unique_ptr<table> build_table(
->release()[0])
: std::make_unique<column>(sliced_child, stream, mr));

if (position_array) {
size_type position_size = position_array->size();
if (include_position) {
size_type position_size = position_array.size();
columns.insert(columns.begin() + explode_column_idx,
std::make_unique<column>(
data_type(type_to_id<size_type>()), position_size, position_array->release()));
data_type(type_to_id<size_type>()), position_size, position_array.release()));
}

return std::make_unique<table>(std::move(columns));
Expand Down Expand Up @@ -109,12 +110,14 @@ std::unique_ptr<table> explode(table_view const& input_table,
counting_iter + gather_map.size(),
gather_map.begin());

rmm::device_uvector<size_type> empty{0, stream};
return build_table(input_table,
explode_column_idx,
sliced_child,
gather_map,
thrust::nullopt,
thrust::nullopt,
std::move(empty),
false,
stream,
mr);
}
Expand Down Expand Up @@ -162,6 +165,7 @@ std::unique_ptr<table> explode_position(table_view const& input_table,
gather_map,
thrust::nullopt,
std::move(pos),
true,
stream,
mr);
}
Expand Down Expand Up @@ -213,46 +217,46 @@ std::unique_ptr<table> explode_outer(table_view const& input_table,
counting_iter,
counting_iter + sliced_child.size(),
[offsets_minus_one,
gather_map = gather_map.begin(),
explode_col_gather_map = explode_col_gather_map.begin(),
position_array = pos.begin(),
_gather_map = gather_map.begin(),
_explode_col_gather_map = explode_col_gather_map.begin(),
_position_array = pos.begin(),
include_position,
offsets,
null_or_empty_offset = null_or_empty_offset.begin(),
_null_or_empty_offset = null_or_empty_offset.begin(),
null_or_empty,
offset_size = explode_col.offsets().size() - 1] __device__(auto idx) {
auto lb_idx = thrust::distance(
offsets_minus_one,
thrust::lower_bound(
thrust::seq, offsets_minus_one, offsets_minus_one + (offset_size), idx));
auto index_to_write = null_or_empty_offset[lb_idx] + idx;
gather_map[index_to_write] = lb_idx;
explode_col_gather_map[index_to_write] = idx;
auto index_to_write = _null_or_empty_offset[lb_idx] + idx;
_gather_map[index_to_write] = lb_idx;
_explode_col_gather_map[index_to_write] = idx;
if (include_position) {
position_array[index_to_write] = idx - (offsets[lb_idx] - offsets[0]);
_position_array[index_to_write] = idx - (offsets[lb_idx] - offsets[0]);
}
if (null_or_empty[idx]) {
auto invalid_index = null_or_empty_offset[idx] == 0
auto invalid_index = _null_or_empty_offset[idx] == 0
? offsets[idx]
: offsets[idx] + null_or_empty_offset[idx] - 1;
gather_map[invalid_index] = idx;
: offsets[idx] + _null_or_empty_offset[idx] - 1;
_gather_map[invalid_index] = idx;

// negative one to indicate a null value
explode_col_gather_map[invalid_index] = -1;
_explode_col_gather_map[invalid_index] = -1;

if (include_position) { position_array[invalid_index] = 0; }
if (include_position) { _position_array[invalid_index] = 0; }
}
});

return build_table(
input_table,
explode_column_idx,
sliced_child,
gather_map,
explode_col_gather_map,
include_position ? std::move(pos) : thrust::optional<rmm::device_uvector<size_type>>{},
stream,
mr);
return build_table(input_table,
explode_column_idx,
sliced_child,
gather_map,
explode_col_gather_map,
std::move(pos),
include_position,
stream,
mr);
}

} // namespace detail
Expand Down

0 comments on commit 8baa2b7

Please sign in to comment.