Skip to content

Commit cc35b66

Browse files
committed
implement the voxelization backward pass
1 parent 7a4a1df commit cc35b66

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

models/ndnetpp/nd.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
class VoxelizerFunction(torch.autograd.Function):
3535
@staticmethod
36-
def forward(ctx, input: torch.Tensor, num_desired_dists: int, voxel_size: float) -> Tuple[torch.Tensor]:
36+
def forward(ctx, input: torch.Tensor, num_desired_dists: int, voxel_size: float) -> torch.Tensor:
3737
"""
3838
Voxelize the input point cloud.
3939
@@ -49,10 +49,13 @@ def forward(ctx, input: torch.Tensor, num_desired_dists: int, voxel_size: float)
4949
# estimate the normal distributions
5050
start = time.time()
5151
# normal distributions shaped (batch_size, voxels_x, voxels_y, voxels_z, 12)
52-
dists, _, min_coords, _ = nd_utils.normal_distributions.estimate_normal_distributions_with_size(input, voxel_size)
52+
dists, _, min_coords, n_voxels = nd_utils.normal_distributions.estimate_normal_distributions_with_size(input, voxel_size)
5353
end = time.time()
5454
print(f"Normal distributions estimation time {dists.device}: {end - start}s - {(end-start)*1000}ms - {1.0 / (end-start)}Hz")
5555

56+
# get the voxel indices of the input point cloud
57+
voxel_idxs_pcd = nd_utils.voxelization.metric_to_voxel_space(input, voxel_size, n_voxels, min_coords)
58+
5659
# randomly sample the input point cloud
5760
sampled_pcd, sampled_idx = nd_utils.point_clouds.random_sample_point_cloud(input, num_desired_dists)
5861

@@ -65,6 +68,9 @@ def forward(ctx, input: torch.Tensor, num_desired_dists: int, voxel_size: float)
6568
# get the normal distributions at the indices
6669
filtered_dists = dists[batch_idxs, neighborhood_idxs[..., 0], neighborhood_idxs[..., 1], neighborhood_idxs[..., 2]]
6770

71+
# save the context
72+
ctx.save_for_backward(voxel_idxs_pcd, filtered_dists, sampled_idx, neighborhood_idxs)
73+
6874
# return the filtered normal distributions
6975
return filtered_dists
7076

@@ -74,14 +80,26 @@ def backward(ctx, dists_grad: torch.Tensor):
7480
Voxelization backward pass.
7581
7682
Args:
77-
dists_grad (torch.Tensor): gradients of the distributions (voxels) losses (N1).
83+
dists_grad (torch.Tensor): gradients of the distributions (voxels) losses shaped (batch_size, n_dists, 12)
7884
7985
Returns:
80-
torch.Tensor: gradients propagated to the points corresponding to each voxel (N).
86+
torch.Tensor: gradients propagated to the points corresponding to each voxel shaped (batch_size, n_points, 3).
8187
"""
8288

83-
# TODO: distribute the voxel gradients to the corresponding points
84-
raise NotImplementedError("VoxelizerFunction.backward is not implemented.")
89+
# retrieve the saved tensors
90+
voxel_idxs_pcd, out_dists, sampled_idx, neighborhood_idxs = ctx.saved_tensors
91+
92+
# sum the last dimension (normal distribution) of the gradients
93+
dists_grad = dists_grad.sum(dim=-1)
94+
95+
# create a mask where each point's voxel index matches the neighborhood index
96+
mask = (voxel_idxs_pcd.unsqueeze(2) == neighborhood_idxs.unsqueeze(1)).all(dim=-1)
97+
98+
# broadcast the gradients to the points
99+
input_grad = torch.zeros_like(voxel_idxs_pcd, dtype=dists_grad.dtype)
100+
input_grad += (mask.float() * dists_grad.unsqueeze(1)).sum(dim=2).unsqueeze(-1)
101+
102+
return input_grad, None, None
85103

86104
class Voxelizer(nn.Module):
87105
"""

0 commit comments

Comments
 (0)