Use non-synchronising scan (#5560)
RAMitchell authored Apr 20, 2020
1 parent d6d1035 commit b2827a8
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions src/tree/gpu_hist/row_partitioner.cu
@@ -40,7 +40,7 @@ struct WriteResultsFunctor {
   common::Span<RowPartitioner::RowIndexT> ridx_out;
   int64_t* d_left_count;
 
-  __device__ int operator()(const IndexFlagTuple& x) {
+  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     // the ex_scan_result represents how many rows have been assigned to left
     // node so far during scan.
     int scatter_address;
@@ -56,10 +56,18 @@ struct WriteResultsFunctor {
     ridx_out[scatter_address] = ridx_in[x.idx];
 
     // Discard
-    return 0;
+    return {};
   }
 };
 
+// Change the value type of thrust discard iterator so we can use it with cub
+class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
+ public:
+  using value_type = IndexFlagTuple;  // NOLINT
+};
+
+// Implement partitioning via single scan operation using transform output to
+// write the result
 void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   common::Span<bst_node_t> position_out,
                                   common::Span<RowIndexT> ridx,
@@ -68,19 +76,21 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   int64_t* d_left_count, cudaStream_t stream) {
   WriteResultsFunctor write_results{left_nidx, position, position_out,
                                     ridx, ridx_out, d_left_count};
-  auto discard_write_iterator = thrust::make_transform_output_iterator(
-      thrust::discard_iterator<int>(), write_results);
+  auto discard_write_iterator =
+      thrust::make_transform_output_iterator(DiscardOverload(), write_results);
+  auto counting = thrust::make_counting_iterator(0llu);
   auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
-      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
+      counting, [=] __device__(size_t idx) {
         return IndexFlagTuple{idx, position[idx] == left_nidx};
       });
-  dh::XGBCachingDeviceAllocator<char> alloc;
-  thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
-                         input_iterator + position.size(),
-                         discard_write_iterator,
-                         [=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
-                           return IndexFlagTuple{b.idx, a.flag + b.flag};
-                         });
+  size_t temp_bytes = 0;
+  cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
+                                 discard_write_iterator, IndexFlagOp(),
+                                 position.size(), stream);
+  dh::TemporaryArray<int8_t> temp(temp_bytes);
+  cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
+                                 discard_write_iterator, IndexFlagOp(),
+                                 position.size(), stream);
 }
 
 RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
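
The change fuses the left/right partition into a single device-wide inclusive scan: the input iterator turns each row index into an IndexFlagTuple, the scan accumulates the running count of left rows, and the transform output iterator scatters every row to its destination while discarding the scan value itself. Replacing thrust::inclusive_scan with the two-phase cub::DeviceScan::InclusiveScan means the host only enqueues work on the given stream instead of blocking until the scan completes, which is the "non-synchronising" part of the title. Below is a minimal standalone sketch of that pattern, not the library's code: StablePartition, MakeFlag, ScatterFunctor, the int/pivot inputs, and the thrust::device_vector used for scratch space are illustrative stand-ins (the real code uses dh::MakeTransformIterator, WriteResultsFunctor, and a cached dh::TemporaryArray), while IndexFlagTuple, IndexFlagOp, and DiscardOverload mirror the file above. It also assumes the total count of "left" elements has already been written to device memory by an earlier kernel, as d_left_count is in RowPartitioner.

// Standalone sketch of the scan-and-scatter partition used above.
// Hypothetical helper names: StablePartition, MakeFlag, ScatterFunctor.
#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>

struct IndexFlagTuple {
  size_t idx;   // original element index
  size_t flag;  // before the scan: 1 if the element goes left; after: running left count
};

struct IndexFlagOp {
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
                                       const IndexFlagTuple& b) const {
    return {b.idx, a.flag + b.flag};
  }
};

// As in the diff: give the discard iterator a usable value_type so cub
// accepts it as the scan's output iterator.
class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
 public:
  using value_type = IndexFlagTuple;  // NOLINT
};

// Turns an element index into its (index, goes-left?) tuple.
struct MakeFlag {
  const int* in;
  int pivot;
  __host__ __device__ IndexFlagTuple operator()(size_t i) const {
    return {i, static_cast<size_t>(in[i] < pivot)};
  }
};

// Invoked by the transform output iterator on each scanned tuple: scatters the
// element to its final slot and throws the scan value away.
struct ScatterFunctor {
  const int* in;
  int* out;
  const size_t* d_left_total;  // total left count, already resident on the device
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
    // x.flag is the inclusive count of left elements up to and including x.idx.
    size_t scatter = x.flag ? x.flag - 1                         // left: packed at the front
                            : *d_left_total + (x.idx - x.flag);  // right: after all left elements
    out[scatter] = in[x.idx];
    return {};  // discarded
  }
};

// Stable partition of d_in around pivot in a single scan, issued on `stream`
// without blocking the host.
void StablePartition(const int* d_in, int* d_out, size_t n, int pivot,
                     const size_t* d_left_total, cudaStream_t stream) {
  auto flags = thrust::make_transform_iterator(
      thrust::make_counting_iterator<size_t>(0), MakeFlag{d_in, pivot});
  auto scatter_out = thrust::make_transform_output_iterator(
      DiscardOverload(), ScatterFunctor{d_in, d_out, d_left_total});
  // Two-phase cub idiom: the first call only reports the scratch space needed,
  // the second call runs the scan on the stream.
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, flags, scatter_out,
                                 IndexFlagOp(), n, stream);
  thrust::device_vector<char> temp(temp_bytes);  // stand-in for the cached dh::TemporaryArray
  cub::DeviceScan::InclusiveScan(thrust::raw_pointer_cast(temp.data()),
                                 temp_bytes, flags, scatter_out, IndexFlagOp(),
                                 n, stream);
}

Neither cub call synchronises the host, whereas the thrust::inclusive_scan entry point it replaces is blocking; the IndexFlagTuple return type on the scatter functor and the DiscardOverload value_type are what let the transform output iterator slot into cub's output-iterator requirements.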
