Skip to content

Commit 6726500

Browse files
bottlerfacebook-github-bot
authored andcommitted
simple warning for bin overflow
Summary: Since coarse rasterization on cuda can overflow bins, we detect when this happens for memory safety. See #348 . Also try to print a warning. Reviewed By: patricklabatut Differential Revision: D33065604 fbshipit-source-id: 99b3c576d01b78e6d77776cf1a3e95984506c93a
1 parent d6a12af commit 6726500

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu

+16
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,22 @@ __global__ void RasterizeCoarseCudaKernel(
183183
// this effectively allocates space in the bin_faces array for the
184184
// elems in the current chunk that fall into this bin.
185185
const int start = atomicAdd(elems_per_bin + elems_per_bin_idx, count);
186+
if (start + count > M) {
187+
// The number of elems in this bin is so big that they won't fit.
188+
// We print a warning using CUDA's printf. This may be invisible
189+
// to notebook users, but apparent to others. It would be nice to
190+
// also have a Python-friendly warning, but it is not obvious
191+
// how to do this without slowing down the normal case.
192+
const char* warning =
193+
"Bin size was too small in the coarse rasterization phase. "
194+
"This caused an overflow, meaning output may be incomplete. "
195+
"To solve, "
196+
"try increasing max_faces_per_bin / max_points_per_bin, "
197+
"decreasing bin_size, "
198+
"or setting bin_size to -1 to use the naive rasterization.";
199+
printf(warning);
200+
continue;
201+
}
186202

187203
// Now loop over the binmask and write the active bits for this bin
188204
// out to bin_faces.

0 commit comments

Comments
 (0)