Skip to content

Commit

Permalink
Fix quadtree spatial join OOMs on large numbers of input polygons (#1381)
Browse files Browse the repository at this point in the history

Followup to #1346.

* Fixes some typos/omissions in types and CMake.
* Adds a new test that OOMs when quadtree_point_in_polygon is passed too many input polygons.
* Fixes quadtree spatial join to handle overflow while counting and to allocate output buffers more conservatively.

Fixes #890.

* [Failing test run](https://github.com/rapidsai/cuspatial/actions/runs/8979838628/job/24662981350#step:7:840)
* [Passing test run](https://github.com/rapidsai/cuspatial/actions/runs/8981106226/job/24666403165#step:7:840)

Authors:
  - Paul Taylor (https://github.com/trxcllnt)

Approvers:
  - Mark Harris (https://github.com/harrism)
  - Michael Wang (https://github.com/isVoid)

URL: #1381
  • Loading branch information
trxcllnt authored May 24, 2024
1 parent eff6753 commit 8840189
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 47 deletions.
23 changes: 21 additions & 2 deletions cpp/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,17 @@ add_library(cuspatial_benchmark_common OBJECT

target_compile_features(cuspatial_benchmark_common PUBLIC cxx_std_17 cuda_std_17)

set_target_properties(cuspatial_benchmark_common
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "$<BUILD_INTERFACE:${CUSPATIAL_BINARY_DIR}/benchmarks>"
INSTALL_RPATH "\$ORIGIN/../../../lib"
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON
INTERFACE_POSITION_INDEPENDENT_CODE ON
)

target_link_libraries(cuspatial_benchmark_common
PUBLIC benchmark::benchmark
cudf::cudftestutil
Expand All @@ -43,6 +54,10 @@ function(ConfigureBench CMAKE_BENCH_NAME)
set_target_properties(${CMAKE_BENCH_NAME}
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "$<BUILD_INTERFACE:${CUSPATIAL_BINARY_DIR}/benchmarks>"
INSTALL_RPATH "\$ORIGIN/../../../lib"
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON
)
target_link_libraries(${CMAKE_BENCH_NAME} PRIVATE benchmark::benchmark_main cuspatial_benchmark_common)
install(
Expand All @@ -61,7 +76,11 @@ function(ConfigureNVBench CMAKE_BENCH_NAME)
${CMAKE_BENCH_NAME}
PROPERTIES RUNTIME_OUTPUT_DIRECTORY "$<BUILD_INTERFACE:${CUSPATIAL_BINARY_DIR}/benchmarks>"
INSTALL_RPATH "\$ORIGIN/../../../lib"
)
CXX_STANDARD 17
CXX_STANDARD_REQUIRED ON
CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON
)
target_link_libraries(
${CMAKE_BENCH_NAME} PRIVATE cuspatial_benchmark_common nvbench::main
)
Expand Down
123 changes: 85 additions & 38 deletions cpp/include/cuspatial/detail/join/quadtree_point_in_polygon.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@
#include <cuspatial/range/multipolygon_range.cuh>
#include <cuspatial/traits.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/iterator/permutation_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/scan.h>

#include <cstdint>
#include <limits>

namespace cuspatial {
namespace detail {
Expand All @@ -57,7 +59,7 @@ struct compute_poly_and_point_indices {
using IndexType = iterator_value_type<QuadOffsetsIterator>;

inline thrust::tuple<IndexType, IndexType> __device__
operator()(IndexType const global_index) const
operator()(std::uint64_t const global_index) const
{
auto const [quad_poly_index, local_point_index] =
get_quad_and_local_point_indices(global_index, point_offsets_begin, point_offsets_end);
Expand Down Expand Up @@ -118,16 +120,26 @@ std::pair<rmm::device_uvector<IndexType>, rmm::device_uvector<IndexType>> quadtr

auto num_poly_quad_pairs = std::distance(poly_indices_first, poly_indices_last);

auto quad_lengths_iter =
thrust::make_permutation_iterator(quadtree.length_begin(), quad_indices_first);
// The quadtree length is an iterator of uint32_t, but we have to transform into uint64_t values
// so the thrust::inclusive_scan accumulates into uint64_t outputs. Changing the output iterator
// to uint64_t isn't sufficient to achieve this behavior.
auto quad_lengths_iter = thrust::make_transform_iterator(
thrust::make_permutation_iterator(quadtree.length_begin(), quad_indices_first),
cuda::proclaim_return_type<std::uint64_t>([] __device__(IndexType const& i) -> std::uint64_t {
return static_cast<std::uint64_t>(i);
}));

auto quad_offsets_iter =
thrust::make_permutation_iterator(quadtree.offset_begin(), quad_indices_first);

// Compute a "local" set of zero-based point offsets from number of points in each quadrant
// Compute a "local" set of zero-based point offsets from the number of points in each quadrant.
//
// Use `num_poly_quad_pairs + 1` as the length so that the last element produced by
// `inclusive_scan` is the total number of points to be tested against any polygon.
rmm::device_uvector<IndexType> local_point_offsets(num_poly_quad_pairs + 1, stream);
//
// Accumulate into uint64_t, because the prefix sums can overflow the size of uint32_t
// when testing a large number of polygons against a large quadtree.
rmm::device_uvector<std::uint64_t> local_point_offsets(num_poly_quad_pairs + 1, stream);

// inclusive scan of quad_lengths_iter
thrust::inclusive_scan(rmm::exec_policy(stream),
Expand All @@ -136,21 +148,27 @@ std::pair<rmm::device_uvector<IndexType>, rmm::device_uvector<IndexType>> quadtr
local_point_offsets.begin() + 1);

// Ensure local point offsets starts at 0
IndexType init{0};
std::uint64_t init{0};
local_point_offsets.set_element_async(0, init, stream);

// The last element is the total number of points to test against any polygon.
auto num_total_points = local_point_offsets.back_element(stream);

// Allocate the output polygon and point index pair vectors
rmm::device_uvector<IndexType> poly_indices(num_total_points, stream);
rmm::device_uvector<IndexType> point_indices(num_total_points, stream);

auto poly_and_point_indices =
thrust::make_zip_iterator(poly_indices.begin(), point_indices.begin());

// Enumerate the point X/Ys using the sorted `point_indices` (from quadtree construction)
auto point_xys_iter = thrust::make_permutation_iterator(points_first, point_indices_first);
// The largest supported input size for thrust::count_if/copy_if is INT32_MAX.
// This functor iterates over the input space and processes up to INT32_MAX elements at a time.
std::uint64_t max_points_to_test = std::numeric_limits<std::int32_t>::max();
auto count_in_chunks = [&](auto const& func) {
std::uint64_t memo{};
for (std::uint64_t offset{0}; offset < num_total_points; offset += max_points_to_test) {
memo += func(memo, offset, std::min(max_points_to_test, num_total_points - offset));
}
return memo;
};

detail::test_poly_point_intersection test_poly_point_pair{
// Enumerate the point X/Ys using the sorted `point_indices` (from quadtree construction)
thrust::make_permutation_iterator(points_first, point_indices_first),
polygons};

// Compute the combination of polygon and point index pairs. For each polygon/quadrant pair,
// enumerate pairs of (poly_index, point_index) for each point in each quadrant.
Expand All @@ -163,28 +181,57 @@ std::pair<rmm::device_uvector<IndexType>, rmm::device_uvector<IndexType>> quadtr
// pp_pairs.append((polygon, point))
// ```
//
auto global_to_poly_and_point_indices = detail::make_counting_transform_iterator(
0,
detail::compute_poly_and_point_indices{quad_offsets_iter,
local_point_offsets.begin(),
local_point_offsets.end(),
poly_indices_first});

// Compute the number of intersections by removing (poly, point) pairs that don't intersect
auto num_intersections = thrust::distance(
poly_and_point_indices,
thrust::copy_if(rmm::exec_policy(stream),
global_to_poly_and_point_indices,
global_to_poly_and_point_indices + num_total_points,
poly_and_point_indices,
detail::test_poly_point_intersection{point_xys_iter, polygons}));

poly_indices.resize(num_intersections, stream);
poly_indices.shrink_to_fit(stream);
point_indices.resize(num_intersections, stream);
point_indices.shrink_to_fit(stream);

return std::pair{std::move(poly_indices), std::move(point_indices)};
auto global_to_poly_and_point_indices = [&](auto offset = 0) {
return detail::make_counting_transform_iterator(
offset,
detail::compute_poly_and_point_indices{quad_offsets_iter,
local_point_offsets.begin(),
local_point_offsets.end(),
poly_indices_first});
};

auto run_quadtree_point_in_polygon = [&](auto output_size) {
// Allocate the output polygon and point index pair vectors
rmm::device_uvector<IndexType> poly_indices(output_size, stream);
rmm::device_uvector<IndexType> point_indices(output_size, stream);

auto num_intersections = count_in_chunks([&](auto memo, auto offset, auto size) {
auto poly_and_point_indices =
thrust::make_zip_iterator(poly_indices.begin(), point_indices.begin()) + memo;
// Remove (poly, point) pairs that don't intersect
return thrust::distance(poly_and_point_indices,
thrust::copy_if(rmm::exec_policy(stream),
global_to_poly_and_point_indices(offset),
global_to_poly_and_point_indices(offset) + size,
poly_and_point_indices,
test_poly_point_pair));
});

if (num_intersections < output_size) {
poly_indices.resize(num_intersections, stream);
point_indices.resize(num_intersections, stream);
poly_indices.shrink_to_fit(stream);
point_indices.shrink_to_fit(stream);
}

return std::pair{std::move(poly_indices), std::move(point_indices)};
};

try {
// First, attempt the hit test assuming that output buffers sized for all possible
// intersections fit into the available memory.
return run_quadtree_point_in_polygon(num_total_points);
} catch (rmm::out_of_memory const&) {
// If we OOM the first time, pre-compute the number of hits and allocate only that amount of
// space for the output buffers. This halves performance, but it should at least return valid
// results.
return run_quadtree_point_in_polygon(count_in_chunks([&](auto memo, auto offset, auto size) {
return thrust::count_if(rmm::exec_policy(stream),
global_to_poly_and_point_indices(offset),
global_to_poly_and_point_indices(offset) + size,
test_poly_point_pair);
}));
}
}

} // namespace cuspatial
3 changes: 2 additions & 1 deletion cpp/include/cuspatial/detail/range/multilinestring_range.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -74,6 +74,7 @@ template <typename GeometryIterator, typename PartIterator, typename VecIterator
class multilinestring_range;

template <typename GeometryIterator, typename PartIterator, typename VecIterator>
CUSPATIAL_HOST_DEVICE
multilinestring_range<GeometryIterator, PartIterator, VecIterator>::multilinestring_range(
GeometryIterator geometry_begin,
GeometryIterator geometry_end,
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/cuspatial/geometry/box.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,9 +40,9 @@ class alignas(sizeof(Vertex)) box {

private:
/**
* @brief Output stream operator for `vec_2d<T>` for human-readable formatting
* @brief Output stream operator for `box<T>` for human-readable formatting
*/
friend std::ostream& operator<<(std::ostream& os, cuspatial::box<T> const& b)
friend std::ostream& operator<<(std::ostream& os, cuspatial::box<T, Vertex> const& b)
{
return os << "{" << b.v1 << ", " << b.v2 << "}";
}
Expand Down
7 changes: 5 additions & 2 deletions cpp/include/cuspatial_test/geometry_generator.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

#include <ranger/ranger.hpp>

#include <cmath>

namespace cuspatial {
namespace test {

Expand Down Expand Up @@ -399,8 +401,9 @@ auto generate_multipoint_array(multipoint_generator_parameter<T> params,
std::size_t{0},
params.num_points_per_multipoints);

auto engine_x = deterministic_engine(params.num_points());
auto engine_y = deterministic_engine(2 * params.num_points());
auto golden_ratio = (1 + std::sqrt(T{5})) / 2;
auto engine_x = deterministic_engine(golden_ratio * params.num_points());
auto engine_y = deterministic_engine((1 / golden_ratio) * params.num_points());

auto x_dist = make_uniform_dist(params.lower_left.x, params.upper_right.x);
auto y_dist = make_uniform_dist(params.lower_left.y, params.upper_right.y);
Expand Down
11 changes: 11 additions & 0 deletions cpp/include/cuspatial_test/vector_factories.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,17 @@ class multipoint_array {
_geometry_offsets.begin(), _geometry_offsets.end(), _coordinates.begin(), _coordinates.end()};
}

/**
* @brief Copy the offset arrays to host.
*/
auto to_host() const
{
auto geometry_offsets = cuspatial::test::to_host<geometry_t>(_geometry_offsets);
auto coordinate_offsets = cuspatial::test::to_host<coord_t>(_coordinates);

return std::tuple{geometry_offsets, coordinate_offsets};
}

/// Release ownership
auto release() { return std::pair{std::move(_geometry_offsets), std::move(_coordinates)}; }

Expand Down
5 changes: 4 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#=============================================================================
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
# Copyright (c) 2019-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -158,6 +158,9 @@ ConfigureTest(JOIN_POINT_IN_POLYGON_SMALL_TEST_EXP
ConfigureTest(JOIN_POINT_IN_POLYGON_LARGE_TEST_EXP
join/quadtree_point_in_polygon_test_large.cu)

ConfigureTest(JOIN_POINT_IN_POLYGON_OOM_TEST_EXP
join/quadtree_point_in_polygon_test_oom.cu)

ConfigureTest(JOIN_POINT_TO_LINESTRING_SMALL_TEST_EXP
join/quadtree_point_to_nearest_linestring_test_small.cu)

Expand Down
Loading

0 comments on commit 8840189

Please sign in to comment.