Skip to content

Commit

Permalink
add mark_invisible_cells in occ_grid
Browse files Browse the repository at this point in the history
add test mode for traverse_grids
  • Loading branch information
Linyou committed Apr 19, 2023
1 parent 09b43b1 commit 064379f
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 76 deletions.
1 change: 1 addition & 0 deletions nerfacc/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def call_cuda(*args, **kwargs):
# grid
ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
traverse_grids = _make_lazy_cuda_func("traverse_grids")
traverse_grids_test = _make_lazy_cuda_func("traverse_grids_test")

# scan
exclusive_sum_by_key = _make_lazy_cuda_func("exclusive_sum_by_key")
Expand Down
158 changes: 128 additions & 30 deletions nerfacc/cuda/csrc/grid.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,37 +46,50 @@ __global__ void traverse_grids_kernel(
float *far_planes,
float step_size,
float cone_angle,
int max_samples_per_ray,
// outputs
bool first_pass,
PackedRaySegmentsSpec intervals,
PackedRaySegmentsSpec samples)
PackedRaySegmentsSpec samples,
// for test time traverse_
int64_t *ray_mask_id)
{
float eps = 1e-6f;

// parallelize over rays
for (int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n_rays; tid += blockDim.x * gridDim.x)
{
// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
if (samples.chunk_cnts != nullptr)
if (!first_pass && samples.chunk_cnts[tid] == 0) continue;
bool test_time = ray_mask_id != nullptr;

float near_plane = near_planes[tid];
float far_plane = far_planes[tid];

int32_t tid_t = tid;
int64_t chunk_start, chunk_start_bin;
if (!first_pass) {
if (!test_time) {
// skip rays that are empty.
if (intervals.chunk_cnts != nullptr)
chunk_start = intervals.chunk_starts[tid];
if (!first_pass && intervals.chunk_cnts[tid] == 0) continue;
if (samples.chunk_cnts != nullptr)
chunk_start_bin = samples.chunk_starts[tid];
if (!first_pass && samples.chunk_cnts[tid] == 0) continue;

if (!first_pass) {
if (intervals.chunk_cnts != nullptr)
chunk_start = intervals.chunk_starts[tid];
if (samples.chunk_cnts != nullptr)
chunk_start_bin = samples.chunk_starts[tid];
}
} else {
chunk_start = tid * max_samples_per_ray * 2;
chunk_start_bin = tid * max_samples_per_ray;
tid_t = ray_mask_id[tid];
}
float near_plane = near_planes[tid];
float far_plane = far_planes[tid];

SingleRaySpec ray = SingleRaySpec(
rays_o + tid * 3, rays_d + tid * 3, near_plane, far_plane);
rays_o + tid_t * 3, rays_d + tid_t * 3, near_plane, far_plane);

int32_t base_hits = tid * n_grids;
int32_t base_t_sorted = tid * n_grids * 2;
int32_t base_hits = tid_t * n_grids;
int32_t base_t_sorted = tid_t * n_grids * 2;

// loop over all intersections along the ray.
int64_t n_intervals = 0;
Expand Down Expand Up @@ -137,8 +150,9 @@ __global__ void traverse_grids_kernel(
// delta.x, delta.y, delta.z, step_index.x, step_index.y, step_index.z
// );


const int3 overflow_index = final_index + step_index;
while (true) {
while (n_samples < max_samples_per_ray) {
float t_traverse = min(tdist.x, min(tdist.y, tdist.z));
t_traverse = fminf(t_traverse, this_tmax);
int64_t cell_id = (
Expand All @@ -162,7 +176,7 @@ __global__ void traverse_grids_kernel(
continuous = false;
} else {
// this cell is not empty, so we need to traverse it.
while (true) {
while (n_samples < max_samples_per_ray) {
float t_next;
if (step_size <= 0.0f) {
t_next = t_traverse;
Expand All @@ -171,29 +185,28 @@ __global__ void traverse_grids_kernel(
if (t_last + dt * 0.5f >= t_traverse) break;
t_next = t_last + dt;
}

// writeout the interval.
if (intervals.chunk_cnts != nullptr) {
if (!continuous) {
if (!first_pass) { // left side of the intervel
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_last;
intervals.ray_indices[idx] = tid;
intervals.ray_indices[idx] = tid_t;
intervals.is_left[idx] = true;
}
n_intervals++;
if (!first_pass) { // right side of the intervel
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_next;
intervals.ray_indices[idx] = tid;
intervals.ray_indices[idx] = tid_t;
intervals.is_right[idx] = true;
}
n_intervals++;
} else {
if (!first_pass) { // right side of the intervel
int64_t idx = chunk_start + n_intervals;
intervals.vals[idx] = t_next;
intervals.ray_indices[idx] = tid;
intervals.ray_indices[idx] = tid_t;
intervals.is_left[idx - 1] = true;
intervals.is_right[idx] = true;
}
Expand All @@ -206,7 +219,11 @@ __global__ void traverse_grids_kernel(
if (!first_pass) {
int64_t idx = chunk_start_bin + n_samples;
samples.vals[idx] = (t_next + t_last) * 0.5f;
samples.ray_indices[idx] = tid;
samples.ray_indices[idx] = tid_t;
samples.is_valid[idx] = true;
if (test_time) {
near_planes[tid] = t_next;
}
}
n_samples++;
}
Expand All @@ -226,18 +243,18 @@ __global__ void traverse_grids_kernel(
break;
}
}
if (n_samples >= max_samples_per_ray) break;
}

if (first_pass) {
if (first_pass || test_time) {
if (intervals.chunk_cnts != nullptr)
intervals.chunk_cnts[tid] = n_intervals;
if (samples.chunk_cnts != nullptr)
if (samples.chunk_cnts != nullptr){
samples.chunk_cnts[tid] = n_samples;
}
}
}
}


__global__ void ray_aabb_intersect_kernel(
const int32_t n_rays, float *rays_o, float *rays_d, float near, float far,
const int32_t n_aabbs, float *aabbs,
Expand Down Expand Up @@ -291,7 +308,8 @@ std::vector<RaySegmentsSpec> traverse_grids(
const float step_size,
const float cone_angle,
const bool compute_intervals,
const bool compute_samples)
const bool compute_samples,
const int max_samples_per_ray)
{
DEVICE_GUARD(rays_o);

Expand Down Expand Up @@ -325,6 +343,7 @@ std::vector<RaySegmentsSpec> traverse_grids(
intervals.memalloc_cnts(n_rays, rays_o.options(), false);
if (compute_samples)
samples.memalloc_cnts(n_rays, rays_o.options(), false);

device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
Expand All @@ -344,16 +363,18 @@ std::vector<RaySegmentsSpec> traverse_grids(
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
max_samples_per_ray,
// outputs
true,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
device::PackedRaySegmentsSpec(samples),
nullptr);

// second pass to record the segments.
if (compute_intervals)
intervals.memalloc_data(true, true);
intervals.memalloc_data_from_chunk(true, true);
if (compute_samples)
samples.memalloc_data(false, false);
samples.memalloc_data_from_chunk(false, false, true);
device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
Expand All @@ -373,15 +394,92 @@ std::vector<RaySegmentsSpec> traverse_grids(
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
max_samples_per_ray,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples));
device::PackedRaySegmentsSpec(samples),
nullptr);

return {intervals, samples};
}


std::vector<RaySegmentsSpec> traverse_grids_test(
// rays
const torch::Tensor ray_mask_id, // [n_rays_chunk]
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
// grids
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_sorted, // [n_rays, n_grids]
const torch::Tensor t_indices, // [n_rays, n_grids]
const torch::Tensor hits, // [n_rays, n_grids]
// torch::Tensor continuous_resume, // [n_rays]
// options
torch::Tensor near_planes,
const torch::Tensor far_planes,
const float step_size,
const float cone_angle,
const int max_samples_per_ray)
{
DEVICE_GUARD(rays_o);

int32_t n_rays = ray_mask_id.size(0);
int32_t n_grids = binaries.size(0);
int3 resolution = make_int3(binaries.size(1), binaries.size(2), binaries.size(3));

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
int32_t max_threads = 512;
int32_t max_blocks = 65535;
dim3 threads = dim3(min(max_threads, n_rays));
dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.x)));

// outputs
int total_steps = n_rays * max_samples_per_ray;

// dummy output
RaySegmentsSpec intervals, samples;

intervals.memalloc_cnts(n_rays, rays_o.options(), true);
samples.memalloc_cnts(n_rays, rays_o.options(), true);

intervals.memalloc_data(total_steps*2, true, true);
samples.memalloc_data(total_steps, false, true, true);

device::traverse_grids_kernel<<<blocks, threads, 0, stream>>>(
// rays
n_rays,
rays_o.data_ptr<float>(), // [n_rays, 3]
rays_d.data_ptr<float>(), // [n_rays, 3]
// grids
n_grids,
resolution,
binaries.data_ptr<bool>(), // [n_grids, resx, resy, resz]
aabbs.data_ptr<float>(), // [n_grids, 6]
// sorted intersections
hits.data_ptr<bool>(), // [n_rays, n_grids]
t_sorted.data_ptr<float>(), // [n_rays, n_grids * 2]
t_indices.data_ptr<int64_t>(), // [n_rays, n_grids * 2]
// options
near_planes.data_ptr<float>(), // [n_rays]
far_planes.data_ptr<float>(), // [n_rays]
step_size,
cone_angle,
max_samples_per_ray,
// outputs
false,
device::PackedRaySegmentsSpec(intervals),
device::PackedRaySegmentsSpec(samples),
ray_mask_id.data_ptr<int64_t>() // [n_rays]
);
samples.compute_chunk_start();

return {intervals, samples};
}

std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, // [n_rays, 3]
const torch::Tensor rays_d, // [n_rays, 3]
Expand Down
55 changes: 40 additions & 15 deletions nerfacc/cuda/csrc/include/data_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ struct RaySegmentsSpec {
torch::Tensor ray_indices; // [n_edges]
torch::Tensor is_left; // [n_edges] have n_bins true values
torch::Tensor is_right; // [n_edges] have n_bins true values
torch::Tensor is_valid; // [n_edges] have n_bins true values

inline void check() {
CHECK_INPUT(vals);
Expand Down Expand Up @@ -80,6 +81,11 @@ struct RaySegmentsSpec {
TORCH_CHECK(is_right.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_right.numel());
}
if (is_valid.defined()) {
CHECK_INPUT(is_valid);
TORCH_CHECK(is_valid.ndimension() == 1);
TORCH_CHECK(vals.numel() == is_valid.numel());
}
}

inline void memalloc_cnts(int32_t n_rays, at::TensorOptions options, bool zero_init = true) {
Expand All @@ -91,30 +97,49 @@ struct RaySegmentsSpec {
}
}

inline int64_t memalloc_data(bool alloc_masks = true, bool zero_init = true) {
inline void memalloc_data(int32_t size, bool alloc_masks = true, bool zero_init = true, bool alloc_valid = false) {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());
TORCH_CHECK(!vals.defined());

torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
int64_t n_edges = cumsum[-1].item<int64_t>();

chunk_starts = cumsum - chunk_cnts;

if (zero_init) {
vals = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kLong));
vals = torch::zeros({size}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::zeros({size}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_left = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
}
} else {
vals = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kLong));
vals = torch::empty({size}, chunk_cnts.options().dtype(torch::kFloat32));
ray_indices = torch::empty({size}, chunk_cnts.options().dtype(torch::kLong));
if (alloc_masks) {
is_left = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({n_edges}, chunk_cnts.options().dtype(torch::kBool));
is_left = torch::empty({size}, chunk_cnts.options().dtype(torch::kBool));
is_right = torch::empty({size}, chunk_cnts.options().dtype(torch::kBool));
}
}
if (alloc_valid) {
is_valid = torch::zeros({size}, chunk_cnts.options().dtype(torch::kBool));
}
}

inline int64_t memalloc_data_from_chunk(bool alloc_masks = true, bool zero_init = true, bool alloc_valid = false) {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());

torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
int64_t n_edges = cumsum[-1].item<int64_t>();

chunk_starts = cumsum - chunk_cnts;
memalloc_data(n_edges, alloc_masks, zero_init, alloc_valid);
return 1;
}

// compute the chunk_start from chunk_cnts
inline int64_t compute_chunk_start() {
TORCH_CHECK(chunk_cnts.defined());
TORCH_CHECK(!chunk_starts.defined());

torch::Tensor cumsum = torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
chunk_starts = cumsum - chunk_cnts;
return 1;
}
};
2 changes: 2 additions & 0 deletions nerfacc/cuda/csrc/include/data_spec_packed.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct PackedRaySegmentsSpec {
ray_indices(spec.ray_indices.defined() ? spec.ray_indices.data_ptr<int64_t>() : nullptr),
is_left(spec.is_left.defined() ? spec.is_left.data_ptr<bool>() : nullptr),
is_right(spec.is_right.defined() ? spec.is_right.data_ptr<bool>() : nullptr),
is_valid(spec.is_valid.defined() ? spec.is_valid.data_ptr<bool>() : nullptr),
// for dimensions
n_edges(spec.vals.defined() ? spec.vals.numel() : 0),
n_rays(spec.chunk_cnts.defined() ? spec.chunk_cnts.size(0) : 0), // for flattened tensor
Expand All @@ -31,6 +32,7 @@ struct PackedRaySegmentsSpec {
int64_t* ray_indices;
bool* is_left;
bool* is_right;
bool* is_valid;

int64_t n_edges;
int32_t n_rays;
Expand Down
Loading

0 comments on commit 064379f

Please sign in to comment.