Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate Instant-NGP inference #197

Merged
merged 22 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nerfacc/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def call_cuda(*args, **kwargs):
# grid: lazily-resolved bindings to the compiled CUDA extension ops
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
traverse_grids = _make_lazy_cuda_func("traverse_grids")
# test-time (inference) variant of traverse_grids — presumably caps samples
# per ray for chunked inference; see csrc/grid.cu (TODO confirm semantics)
traverse_grids_test = _make_lazy_cuda_func("traverse_grids_test")

# scan
exclusive_sum_by_key = _make_lazy_cuda_func("exclusive_sum_by_key")
Expand Down
153 changes: 126 additions & 27 deletions nerfacc/cuda/csrc/grid.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,37 +46,51 @@ __global__ void traverse_grids_kernel(
float *far_planes,
float step_size,
float cone_angle,
int max_samples_per_ray,
// outputs
bool first_pass,
PackedRaySegmentsSpec intervals,
PackedRaySegmentsSpec samples)
PackedRaySegmentsSpec samples,
// for test time traverse_
int64_t *ray_mask_id)
{
float eps = 1e-6f;

// parallelize over rays
for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_rays; tid += blockDim.x * gridDim.x)
{
// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
if (samples.chunk_cnts != nullptr)
if (!first_pass && samples.chunk_cnts[tid] == 0) continue;
bool test_time = ray_mask_id != nullptr;

float near_plane = near_planes[tid];
float far_plane = far_planes[tid];

int32_t tid_t = tid;
int64_t chunk_start, chunk_start_bin;
if (!first_pass) {
if (!test_time) {
// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
chunk_start = intervals.chunk_starts[tid];
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
if (samples.chunk_cnts != nullptr)
chunk_start_bin = samples.chunk_starts[tid];
if (!first_pass && samples.chunk_cnts[tid] == 0) continue;

if (!first_pass) {
if (intervals.chunk_cnts != nullptr)
chunk_start = intervals.chunk_starts[tid];
if (samples.chunk_cnts != nullptr)
chunk_start_bin = samples.chunk_starts[tid];
}
} else {
chunk_start = tid * max_samples_per_ray * 2;
chunk_start_bin = tid * max_samples_per_ray;
// ray_mask_id stores the original ray id for each test time ray.
tid_t = ray_mask_id[tid];
}
float near_plane = near_planes[tid];
float far_plane = far_planes[tid];

SingleRaySpec ray = SingleRaySpec(
rays_o + tid * 3, rays_d + tid * 3, near_plane, far_plane);
rays_o + tid_t * 3, rays_d + tid_t * 3, near_plane, far_plane);

int32_t base_hits = tid * n_grids;
int32_t base_t_sorted = tid * n_grids * 2;
int32_t base_hits = tid_t * n_grids;
int32_t base_t_sorted = tid_t * n_grids * 2;

// loop over all intersections along the ray.
int64_t n_intervals = 0;
Expand Down Expand Up @@ -137,8 +151,9 @@ __global__ void traverse_grids_kernel(
// delta.x, delta.y, delta.z, step_index.x, step_index.y, step_index.z
// );


const int3 overflow_index = final_index + step_index;
while (true) {
while (n_samples < max_samples_per_ray) {
float t_traverse = min(tdist.x, min(tdist.y, tdist.z));
t_traverse = fminf(t_traverse, this_tmax);
int64_t cell_id = (
Expand All @@ -162,7 +177,7 @@ __global__ void traverse_grids_kernel(
continuous = false;
} else {
// this cell is not empty, so we need to traverse it.
while (true) {
while (n_samples < max_samples_per_ray) {
float t_next;
if (step_size <= 0.0f) {
t_next = t_traverse;
Expand All @@ -171,7 +186,6 @@ __global__ void traverse_grids_kernel(
if (t_last + dt * 0.5f >= t_traverse) break;
t_next = t_last + dt;
}

// writeout the interval.
if (intervals.chunk_cnts != nullptr) {
if (!continuous) {
Expand Down Expand Up @@ -206,7 +220,11 @@ __global__ void traverse_grids_kernel(
if (!first_pass) {
int64_t idx = chunk_start_bin + n_samples;
samples.vals[idx] = (t_next + t_last) * 0.5f;
samples.ray_indices[idx] = tid;
samples.ray_indices[idx] = tid_t;
samples.is_valid[idx] = true;
if (test_time) {
near_planes[tid] = t_next;
}
}
n_samples++;
}
Expand All @@ -226,18 +244,18 @@ __global__ void traverse_grids_kernel(
break;
}
}
if (n_samples >= max_samples_per_ray) break;
}

if (first_pass) {
if (first_pass || test_time) {
if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
if (samples.chunk_cnts != nullptr){
samples.chunk_cnts[tid] = n_samples;
}
}
}
}


__global__ void ray_aabb_intersect_kernel(
const int32_t n_rays, float *rays_o, float *rays_d, float near, float far,
const int32_t n_aabbs, float *aabbs,
Expand Down Expand Up @@ -291,7 +309,8 @@ std::vector<RaySegmentsSpec> traverse_grids(
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples)
const bool compute_samples,
const int max_samples_per_ray)
{
DEVICE_GUARD(rays_o);

Expand Down Expand Up @@ -325,6 +344,7 @@ std::vector<RaySegmentsSpec> traverse_grids(
intervals.memalloc_cnts(n_rays, rays_o.options(), false);
if (compute_samples)
samples.memalloc_cnts(n_rays, rays_o.options(), false);

device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
Expand All @@ -344,16 +364,18 @@ std::vector<RaySegmentsSpec> traverse_grids(
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
max_samples_per_ray,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
device::PackedRaySegmentsSpec(samples),
nullptr);

// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data(true, true);
intervals.memalloc_data_from_chunk(true, true);
if (compute_samples)
samples.memalloc_data(false, false);
samples.memalloc_data_from_chunk(false, false, true);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
Expand All @@ -373,15 +395,92 @@ std::vector<RaySegmentsSpec> traverse_grids(
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
max_samples_per_ray,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
device::PackedRaySegmentsSpec(samples),
nullptr);

return {intervals, samples};
}


// Test-time (inference) grid traversal.
//
// Unlike traverse_grids(), this runs a single kernel pass: every active ray
// owns a fixed stride of `max_samples_per_ray` slots in pre-allocated
// buffers, so no counting pass is needed. `ray_mask_id` maps each active
// (not-yet-terminated) ray back to its original ray id, which is used to
// index rays_o/rays_d and the precomputed intersection tables.
//
// NOTE(review): the kernel writes back into `near_planes` (resume point for
// the next chunk), so it is intentionally non-const — confirm callers rely
// on this in-place update.
// NOTE(review): t_sorted/t_indices look like [n_rays, n_grids * 2] (two
// boundaries per grid hit), matching the kernel-call comments below; the
// original declaration comments said [n_rays, n_grids] — verify.
//
// Returns {intervals, samples}; `samples.chunk_starts` is recomputed from
// the per-ray counts filled in by the kernel.
std::vector<RaySegmentsSpec> traverse_grids_test(
    // rays
    const torch::Tensor ray_mask_id,  // [n_rays]
    const torch::Tensor rays_o,       // [n_total_rays, 3]
    const torch::Tensor rays_d,       // [n_total_rays, 3]
    // grids
    const torch::Tensor binaries,     // [n_grids, resx, resy, resz]
    const torch::Tensor aabbs,        // [n_grids, 6]
    // intersections
    const torch::Tensor t_sorted,     // [n_rays, n_grids * 2]
    const torch::Tensor t_indices,    // [n_rays, n_grids * 2]
    const torch::Tensor hits,         // [n_rays, n_grids]
    // options
    torch::Tensor near_planes,        // [n_rays], updated in-place by kernel
    const torch::Tensor far_planes,   // [n_rays]
    const float step_size,
    const float cone_angle,
    const int max_samples_per_ray)
{
    DEVICE_GUARD(rays_o);

    int32_t n_rays = ray_mask_id.size(0);
    int32_t n_grids = binaries.size(0);
    int3 resolution = make_int3(binaries.size(1), binaries.size(2), binaries.size(3));

    at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
    int32_t max_threads = 512;
    int32_t max_blocks = 65535;
    dim3 threads = dim3(min(max_threads, n_rays));
    dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.x)));

    // Fixed-stride layout: each ray gets max_samples_per_ray sample slots
    // (and twice that many interval-edge slots).
    int total_steps = n_rays * max_samples_per_ray;

    RaySegmentsSpec intervals, samples;

    intervals.memalloc_cnts(n_rays, rays_o.options(), true);
    samples.memalloc_cnts(n_rays, rays_o.options(), true);

    intervals.memalloc_data(total_steps * 2, true, true);
    samples.memalloc_data(total_steps, false, true, true);  // alloc is_valid mask

    device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
        // rays
        n_rays,
        rays_o.data_ptr<float>(),   // [n_total_rays, 3]
        rays_d.data_ptr<float>(),   // [n_total_rays, 3]
        // grids
        n_grids,
        resolution,
        binaries.data_ptr<bool>(),  // [n_grids, resx, resy, resz]
        aabbs.data_ptr<float>(),    // [n_grids, 6]
        // sorted intersections
        hits.data_ptr<bool>(),      // [n_rays, n_grids]
        t_sorted.data_ptr<float>(),   // [n_rays, n_grids * 2]
        t_indices.data_ptr<int64_t>(),// [n_rays, n_grids * 2]
        // options
        near_planes.data_ptr<float>(),  // [n_rays]
        far_planes.data_ptr<float>(),   // [n_rays]
        step_size,
        cone_angle,
        max_samples_per_ray,
        // outputs
        false,  // not a counting pass: write samples directly
        device::PackedRaySegmentsSpec(intervals),
        device::PackedRaySegmentsSpec(samples),
        ray_mask_id.data_ptr<int64_t>()  // [n_rays]
    );
    // Kernel filled samples.chunk_cnts; derive chunk_starts from them.
    samples.compute_chunk_start();

    return {intervals, samples};
}

std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
Expand Down
55 changes: 40 additions & 15 deletions nerfacc/cuda/csrc/include/data_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct RaySegmentsSpec {
torch::Tensor ray_indices; // [n_edges]
torch::Tensor is_left; // [n_edges] have n_bins true values
torch::Tensor is_right; // [n_edges] have n_bins true values
torch::Tensor is_valid; // [n_edges] have n_bins true values

inline void check() {
CHECK_INPUT(vals);
Expand Down Expand Up @@ -80,6 +81,11 @@ struct RaySegmentsSpec {
TORCH_CHECK(is_right.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_right.numel());
}
if (is_valid.defined()) {
CHECK_INPUT(is_valid);
TORCH_CHECK(is_valid.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_valid.numel());
}
}

inline void memalloc_cnts(int32_t n_rays, at::TensorOptions options, bool zero_init = true) {
Expand All @@ -91,30 +97,49 @@ struct RaySegmentsSpec {
}
}

inline int64_t memalloc_data(bool alloc_masks = true, bool zero_init = true) {
inline void memalloc_data(int32_t size, bool alloc_masks = true, bool zero_init = true, bool alloc_valid = false) {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());
TORCH_CHECK(!vals.defined());

torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
int64_t n_edges = cumsum[-1].item<int64_t>();

chunk_starts = cumsum - chunk_cnts;

if (zero_init) {
vals = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kLong));
vals = torch::zeros({size}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({size}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_left = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
}
} else {
vals = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kLong));
vals = torch::empty({size}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({size}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_left = torch::empty({size}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({size}, chunk_cnts.options().dtype(torch::kBool));
}
}
if (alloc_valid) {
is_valid = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
}
}

// Allocate the flat data tensors sized from the already-filled per-ray
// chunk counts: chunk_starts becomes the exclusive prefix sum of
// chunk_cnts, and memalloc_data() allocates one slot per counted edge.
//
// NOTE(review): n_edges is int64_t but memalloc_data() takes an int32_t
// size — could silently truncate for very large scenes; confirm upstream.
// NOTE(review): the return value is always 1 and looks unused — could be
// void, but kept for interface compatibility.
inline int64_t memalloc_data_from_chunk(bool alloc_masks = true, bool zero_init = true, bool alloc_valid = false) {
    TORCH_CHECK(chunk_cnts.defined());
    TORCH_CHECK(!chunk_starts.defined());

    torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
    int64_t n_edges = cumsum[-1].item<int64_t>();

    // Exclusive prefix sum: start offset of each ray's chunk.
    chunk_starts = cumsum - chunk_cnts;
    memalloc_data(n_edges, alloc_masks, zero_init, alloc_valid);
    return 1;
}

// Compute chunk_starts from chunk_cnts (exclusive prefix sum) without
// allocating any data tensors — used after a kernel has filled the counts
// into pre-allocated fixed-stride buffers.
//
// NOTE(review): the return value is always 1 and looks unused — could be
// void, but kept for interface compatibility.
inline int64_t compute_chunk_start() {
    TORCH_CHECK(chunk_cnts.defined());
    TORCH_CHECK(!chunk_starts.defined());

    torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
    // Exclusive prefix sum: start offset of each ray's chunk.
    chunk_starts = cumsum - chunk_cnts;
    return 1;
}
};
2 changes: 2 additions & 0 deletions nerfacc/cuda/csrc/include/data_spec_packed.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct PackedRaySegmentsSpec {
ray_indices(spec.ray_indices.defined() ? spec.ray_indices.data_ptr<int64_t>() : nullptr),
is_left(spec.is_left.defined() ? spec.is_left.data_ptr<bool>() : nullptr),
is_right(spec.is_right.defined() ? spec.is_right.data_ptr<bool>() : nullptr),
is_valid(spec.is_valid.defined() ? spec.is_valid.data_ptr<bool>() : nullptr),
// for dimensions
n_edges(spec.vals.defined() ? spec.vals.numel() : 0),
n_rays(spec.chunk_cnts.defined() ? spec.chunk_cnts.size(0) : 0), // for flattened tensor
Expand All @@ -31,6 +32,7 @@ struct PackedRaySegmentsSpec {
int64_t* ray_indices;
bool* is_left;
bool* is_right;
bool* is_valid;

int64_t n_edges;
int32_t n_rays;
Expand Down
Loading