Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate Instant-NGP inference #197

Merged
merged 22 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 180 additions & 88 deletions nerfacc/cuda/csrc/grid.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,50 @@ inline __device__ float _calc_dt(
return clamp(t * cone_angle, dt_min, dt_max);
}

/* Ray traversal within multiple voxel grids.

About rays:
Each ray is defined by its origin (rays_o) and unit direction (rays_d). We also allow
an optional boolean ray mask (rays_mask) to indicate whether we want to skip some rays.

About voxel grids:
We support ray traversal through one or more voxel grids (n_grids). Each grid is defined
by an axis-aligned AABB (aabbs), and a binary occupancy grid (binaries) with resolution of
{resx, resy, resz}. Currently, we assume all grids have the same resolution. Note the ordering
of the grids is important when there are overlapping grids, because we assume the grid in front
has higher priority when examining occupancy status (e.g., the first grid's occupancy status
will overwrite the second grid's occupancy status if they overlap).

About ray grid intersections:
We require the ray grid intersections to be precomputed and sorted. Specifically, if hit, each
ray-grid pair has two intersections, one for entering the grid and one for leaving the grid.
For multiple grids, there are in total 2 * n_grids intersections for each ray. The intersections
are sorted by the distance to the ray origin (t_sorted). We take a boolean array (hits) to indicate
whether each ray-grid pair is hit. We also need an int64 array (t_indices) to indicate the grid id
(0-index) for each intersection.

About ray traversal:
The ray is traversed through the grids in the order of the sorted intersections. We allow per-ray
near and far planes (near_planes, far_planes) to be specified. Early termination can be controlled by
setting the maximum traverse steps via traverse_steps_limit. We also allow an optional step size
(step_size) to be specified. If step_size <= 0.0, we will record the steps where the ray passes through
each voxel cell. Otherwise, we will use the step_size to march through the grids. When step_size > 0.0,
we also allow a cone angle (cone_angle) to be provided, to linearly increase the step size as the ray
goes further away from the origin (see _calc_dt()). cone_angle should be always >= 0.0, and 0.0
means uniform marching with step_size.

About outputs:
The traversal intervals and samples are stored in `intervals` and `samples` respectively. Additionally,
we also return where the traversal actually terminates (terminate_planes). This is useful when
traverse_steps_limit is set (traverse_steps_limit > 0) as the ray may not reach the far plane or the
boundary of the grids.
*/
__global__ void traverse_grids_kernel(
// rays
int32_t n_rays,
float *rays_o, // [n_rays, 3]
float *rays_d, // [n_rays, 3]
bool *rays_mask, // [n_rays]
// grids
int32_t n_grids,
int3 resolution,
Expand All @@ -42,20 +81,24 @@ __global__ void traverse_grids_kernel(
float *t_sorted, // [n_rays, n_grids * 2]
int64_t *t_indices, // [n_rays, n_grids * 2]
// options
float *near_planes,
float *far_planes,
float *near_planes, // [n_rays]
float *far_planes, // [n_rays]
float step_size,
float cone_angle,
int32_t traverse_steps_limit,
// outputs
bool first_pass,
PackedRaySegmentsSpec intervals,
PackedRaySegmentsSpec samples)
PackedRaySegmentsSpec samples,
float *terminate_planes)
{
float eps = 1e-6f;

// parallelize over rays
for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_rays; tid += blockDim.x * gridDim.x)
{
if (rays_mask != nullptr && !rays_mask[tid]) continue;

// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
Expand Down Expand Up @@ -138,7 +181,7 @@ __global__ void traverse_grids_kernel(
// );

const int3 overflow_index = final_index + step_index;
while (true) {
while (traverse_steps_limit <= 0 || n_samples < traverse_steps_limit) {
float t_traverse = min(tdist.x, min(tdist.y, tdist.z));
t_traverse = fminf(t_traverse, this_tmax);
int64_t cell_id = (
Expand All @@ -162,7 +205,7 @@ __global__ void traverse_grids_kernel(
continuous = false;
} else {
// this cell is not empty, so we need to traverse it.
while (true) {
while (traverse_steps_limit <= 0 || n_samples < traverse_steps_limit) {
float t_next;
if (step_size <= 0.0f) {
t_next = t_traverse;
Expand Down Expand Up @@ -207,10 +250,11 @@ __global__ void traverse_grids_kernel(
int64_t idx = chunk_start_bin + n_samples;
samples.vals[idx] = (t_next + t_last) * 0.5f;
samples.ray_indices[idx] = tid;
samples.is_valid[idx] = true;
}
n_samples++;
}

n_samples++;
continuous = true;
t_last = t_next;
if (t_next >= t_traverse) break;
Expand All @@ -227,17 +271,16 @@ __global__ void traverse_grids_kernel(
}
}
}

if (first_pass) {
if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
samples.chunk_cnts[tid] = n_samples;
}
if (terminate_planes != nullptr)
terminate_planes[tid] = t_last;

if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
samples.chunk_cnts[tid] = n_samples;
}
}


__global__ void ray_aabb_intersect_kernel(
const int32_t n_rays, float *rays_o, float *rays_d, float near, float far,
const int32_t n_aabbs, float *aabbs,
Expand Down Expand Up @@ -274,26 +317,33 @@ __global__ void ray_aabb_intersect_kernel(
} // namespace


std::vector<RaySegmentsSpec> traverse_grids(
std::tuple<RaySegmentsSpec, RaySegmentsSpec, torch::Tensor> traverse_grids(
// rays
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
const torch::Tensor rays_mask, // [n_rays]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_mins, // [n_rays, n_grids]
const torch::Tensor t_maxs, // [n_rays, n_grids]
const torch::Tensor t_sorted, // [n_rays, n_grids * 2]
const torch::Tensor t_indices, // [n_rays, n_grids * 2]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples)
const bool compute_samples,
const bool compute_terminate_planes,
const int32_t traverse_steps_limit, // <= 0 means no limit
const bool over_allocate) // over allocate the memory for intervals and samples
{
DEVICE_GUARD(rays_o);
if (over_allocate) {
TORCH_CHECK(traverse_steps_limit > 0, "traverse_steps_limit must be > 0 when over_allocate is true");
}

int32_t n_rays = rays_o.size(0);
int32_t n_grids = binaries.size(0);
Expand All @@ -305,80 +355,122 @@ std::vector<RaySegmentsSpec> traverse_grids(
dim3 threads = dim3(min(max_threads, n_rays));
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.x)));

// Sort the intersections. [n_rays, n_grids * 2]
torch::Tensor t_sorted, t_indices;
if (n_grids > 1) {
std::tie(t_sorted, t_indices) = torch::sort(torch::cat({t_mins, t_maxs}, -1), -1);
}
else {
t_sorted = torch::cat({t_mins, t_maxs}, -1);
t_indices = torch::arange(
0, n_grids * 2, t_mins.options().dtype(torch::kLong)
).expand({n_rays, n_grids * 2}).contiguous();
}

// outputs
RaySegmentsSpec intervals, samples;
torch::Tensor terminate_planes;
if (compute_terminate_planes)
terminate_planes = torch::empty({n_rays}, rays_o.options());

if (over_allocate) {
// over allocate the memory so that we can traverse the grids in a single pass.
if (compute_intervals) {
intervals.chunk_cnts = torch::full({n_rays}, traverse_steps_limit * 2, rays_o.options().dtype(torch::kLong)) * rays_mask;
intervals.memalloc_data_from_chunk(true, true);
}
if (compute_samples) {
samples.chunk_cnts = torch::full({n_rays}, traverse_steps_limit, rays_o.options().dtype(torch::kLong)) * rays_mask;
samples.memalloc_data_from_chunk(false, true, true);
}

// first pass to count the number of segments along each ray.
if (compute_intervals)
intervals.memalloc_cnts(n_rays, rays_o.options(), false);
if (compute_samples)
samples.memalloc_cnts(n_rays, rays_o.options(), false);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
rays_mask.data_ptr<bool>(), // [n_rays]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
traverse_steps_limit,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
compute_terminate_planes ? terminate_planes.data_ptr<float>() : nullptr);

// update the chunk starts with the actual chunk_cnts from traversal.
intervals.compute_chunk_start();
samples.compute_chunk_start();
} else {
// To allocate the accurate memory we need to traverse the grids twice.
// The first pass is to count the number of segments along each ray.
// The second pass is to fill the segments.
if (compute_intervals)
intervals.chunk_cnts = torch::empty({n_rays}, rays_o.options().dtype(torch::kLong));
if (compute_samples)
samples.chunk_cnts = torch::empty({n_rays}, rays_o.options().dtype(torch::kLong));
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
nullptr, /* rays_mask */
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
traverse_steps_limit,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
nullptr); /* terminate_planes */

// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data_from_chunk(true, true);
if (compute_samples)
samples.memalloc_data_from_chunk(false, false, true);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
nullptr, /* rays_mask */
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
traverse_steps_limit,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
compute_terminate_planes ? terminate_planes.data_ptr<float>() : nullptr);
}

// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data(true, true);
if (compute_samples)
samples.memalloc_data(false, false);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));

return {intervals, samples};
return {intervals, samples, terminate_planes};
}


Expand Down
Loading