Skip to content

Commit

Permalink
fix new quadtree spatial join OOM test
Browse files Browse the repository at this point in the history
  • Loading branch information
trxcllnt committed May 7, 2024
1 parent 781fc7d commit 57354cb
Showing 1 changed file with 63 additions and 42 deletions.
105 changes: 63 additions & 42 deletions cpp/include/cuspatial/detail/join/quadtree_point_in_polygon.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@
#include <cuspatial/range/multipolygon_range.cuh>
#include <cuspatial/traits.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/scan.h>

#include <cstdint>

namespace cuspatial {
namespace detail {

Expand All @@ -57,7 +57,7 @@ struct compute_poly_and_point_indices {
using IndexType = iterator_value_type<QuadOffsetsIterator>;

inline thrust::tuple<IndexType, IndexType> __device__
operator()(IndexType const global_index) const
operator()(std::uint64_t const global_index) const
{
auto const [quad_poly_index, local_point_index] =
get_quad_and_local_point_indices(global_index, point_offsets_begin, point_offsets_end);
Expand Down Expand Up @@ -118,16 +118,22 @@ std::pair<rmm::device_uvector<IndexType>, rmm::device_uvector<IndexType>> quadtr

auto num_poly_quad_pairs = std::distance(poly_indices_first, poly_indices_last);

auto quad_lengths_iter =
thrust::make_permutation_iterator(quadtree.length_begin(), quad_indices_first);
auto quad_lengths_iter = thrust::make_transform_iterator(
thrust::make_permutation_iterator(quadtree.length_begin(), quad_indices_first),
cuda::proclaim_return_type<std::uint64_t>([] __device__(IndexType const& i) -> std::uint64_t {
return static_cast<std::uint64_t>(i);
}));

auto quad_offsets_iter =
thrust::make_permutation_iterator(quadtree.offset_begin(), quad_indices_first);

// Enumerate the point X/Ys using the sorted `point_indices` (from quadtree construction)
auto point_xys_iter = thrust::make_permutation_iterator(points_first, point_indices_first);

// Compute a "local" set of zero-based point offsets from the number of points in each quadrant
// Use `num_poly_quad_pairs + 1` as the length so that the last element produced by
// `inclusive_scan` is the total number of points to be tested against any polygon.
rmm::device_uvector<IndexType> local_point_offsets(num_poly_quad_pairs + 1, stream);
rmm::device_uvector<std::uint64_t> local_point_offsets(num_poly_quad_pairs + 1, stream);

// inclusive scan of quad_lengths_iter
thrust::inclusive_scan(rmm::exec_policy(stream),
Expand All @@ -136,52 +142,67 @@ std::pair<rmm::device_uvector<IndexType>, rmm::device_uvector<IndexType>> quadtr
local_point_offsets.begin() + 1);

// Ensure the local point offsets start at 0
IndexType init{0};
std::uint64_t init{0};
local_point_offsets.set_element_async(0, init, stream);

// The last element is the total number of points to test against any polygon.
auto num_total_points = local_point_offsets.back_element(stream);

// Allocate the output polygon and point index pair vectors
rmm::device_uvector<IndexType> poly_indices(num_total_points, stream);
rmm::device_uvector<IndexType> point_indices(num_total_points, stream);

auto poly_and_point_indices =
thrust::make_zip_iterator(poly_indices.begin(), point_indices.begin());
// Process in chunks of *at most* 25% of the free available device memory to avoid OOM'ing
// when allocating additional output space.
auto max_chunk_size =
static_cast<std::uint64_t>(rmm::percent_of_free_device_memory(25) / (sizeof(IndexType) * 2));

// Enumerate the point X/Ys using the sorted `point_indices` (from quadtree construction)
auto point_xys_iter = thrust::make_permutation_iterator(points_first, point_indices_first);
auto num_chunks = std::uint64_t{num_total_points / max_chunk_size + 1};

// Compute the combination of polygon and point index pairs. For each polygon/quadrant pair,
// enumerate pairs of (poly_index, point_index) for each point in each quadrant.
//
// In Python pseudocode:
// ```
// pp_pairs = []
// for polygon, quadrant in pq_pairs:
// for point in quadrant:
// pp_pairs.append((polygon, point))
// ```
//
auto global_to_poly_and_point_indices = detail::make_counting_transform_iterator(
0,
detail::compute_poly_and_point_indices{quad_offsets_iter,
local_point_offsets.begin(),
local_point_offsets.end(),
poly_indices_first});

// Compute the number of intersections by removing (poly, point) pairs that don't intersect
auto num_intersections = thrust::distance(
poly_and_point_indices,
thrust::copy_if(rmm::exec_policy(stream),
global_to_poly_and_point_indices,
global_to_poly_and_point_indices + num_total_points,
poly_and_point_indices,
detail::test_poly_point_intersection{point_xys_iter, polygons}));
// Allocate the output polygon and point index pair vectors
rmm::device_uvector<IndexType> poly_indices(0, stream);
rmm::device_uvector<IndexType> point_indices(0, stream);

std::uint64_t num_intersections = 0;

for (std::uint64_t chunk_index = 0, chunk_offset = 0; chunk_index < num_chunks; ++chunk_index) {
auto chunk_size = std::min(max_chunk_size, num_total_points - (chunk_index * max_chunk_size));

poly_indices.resize(num_intersections + chunk_size, stream);
point_indices.resize(num_intersections + chunk_size, stream);

// Compute the combination of polygon and point index pairs. For each polygon/quadrant pair,
// enumerate pairs of (poly_index, point_index) for each point in each quadrant.
//
// In Python pseudocode:
// ```
// pp_pairs = []
// for polygon, quadrant in pq_pairs:
// for point in quadrant:
// pp_pairs.append((polygon, point))
// ```
//
auto global_to_poly_and_point_indices = detail::make_counting_transform_iterator(
chunk_offset,
detail::compute_poly_and_point_indices{quad_offsets_iter,
local_point_offsets.begin(),
local_point_offsets.end(),
poly_indices_first});

auto poly_and_point_indices =
thrust::make_zip_iterator(poly_indices.begin(), point_indices.begin()) + num_intersections;

// Compute the number of intersections by removing (poly, point) pairs that don't intersect
num_intersections += thrust::distance(
poly_and_point_indices,
thrust::copy_if(rmm::exec_policy(stream),
global_to_poly_and_point_indices,
global_to_poly_and_point_indices + chunk_size,
poly_and_point_indices,
detail::test_poly_point_intersection{point_xys_iter, polygons}));

chunk_offset += chunk_size;
}

poly_indices.resize(num_intersections, stream);
poly_indices.shrink_to_fit(stream);
point_indices.resize(num_intersections, stream);
poly_indices.shrink_to_fit(stream);
point_indices.shrink_to_fit(stream);

return std::pair{std::move(poly_indices), std::move(point_indices)};
Expand Down

0 comments on commit 57354cb

Please sign in to comment.