33
33
34
34
class VoxelizerFunction (torch .autograd .Function ):
35
35
@staticmethod
36
- def forward (ctx , input : torch .Tensor , num_desired_dists : int , voxel_size : float ) -> Tuple [ torch .Tensor ] :
36
+ def forward (ctx , input : torch .Tensor , num_desired_dists : int , voxel_size : float ) -> torch .Tensor :
37
37
"""
38
38
Voxelize the input point cloud.
39
39
@@ -49,10 +49,13 @@ def forward(ctx, input: torch.Tensor, num_desired_dists: int, voxel_size: float)
49
49
# estimate the normal distributions
50
50
start = time .time ()
51
51
# normal distributions shaped (batch_size, voxels_x, voxels_y, voxels_z, 12)
52
- dists , _ , min_coords , _ = nd_utils .normal_distributions .estimate_normal_distributions_with_size (input , voxel_size )
52
+ dists , _ , min_coords , n_voxels = nd_utils .normal_distributions .estimate_normal_distributions_with_size (input , voxel_size )
53
53
end = time .time ()
54
54
print (f"Normal distributions estimation time { dists .device } : { end - start } s - { (end - start )* 1000 } ms - { 1.0 / (end - start )} Hz" )
55
55
56
+ # get the voxel indices of the input point cloud
57
+ voxel_idxs_pcd = nd_utils .voxelization .metric_to_voxel_space (input , voxel_size , n_voxels , min_coords )
58
+
56
59
# randomly sample the input point cloud
57
60
sampled_pcd , sampled_idx = nd_utils .point_clouds .random_sample_point_cloud (input , num_desired_dists )
58
61
@@ -65,6 +68,9 @@ def forward(ctx, input: torch.Tensor, num_desired_dists: int, voxel_size: float)
65
68
# get the normal distributions at the indices
66
69
filtered_dists = dists [batch_idxs , neighborhood_idxs [..., 0 ], neighborhood_idxs [..., 1 ], neighborhood_idxs [..., 2 ]]
67
70
71
+ # save the context
72
+ ctx .save_for_backward (voxel_idxs_pcd , filtered_dists , sampled_idx , neighborhood_idxs )
73
+
68
74
# return the filtered normal distributions
69
75
return filtered_dists
70
76
@@ -74,14 +80,26 @@ def backward(ctx, dists_grad: torch.Tensor):
74
80
Voxelization backward pass.
75
81
76
82
Args:
77
- dists_grad (torch.Tensor): gradients of the distributions (voxels) losses (N1).
83
+ dists_grad (torch.Tensor): gradients of the distributions (voxels) losses shaped (batch_size, n_dists, 12)
78
84
79
85
Returns:
80
- torch.Tensor: gradients propagated to the points corresponding to each voxel (N ).
86
+ torch.Tensor: gradients propagated to the points corresponding to each voxel shaped (batch_size, n_points, 3 ).
81
87
"""
82
88
83
- # TODO: distribute the voxel gradients to the corresponding points
84
- raise NotImplementedError ("VoxelizerFunction.backward is not implemented." )
89
+ # retrieve the saved tensors
90
+ voxel_idxs_pcd , out_dists , sampled_idx , neighborhood_idxs = ctx .saved_tensors
91
+
92
+ # sum the last dimension (normal distribution) of the gradients
93
+ dists_grad = dists_grad .sum (dim = - 1 )
94
+
95
+ # create a mask where each point's voxel index matches the neighborhood index
96
+ mask = (voxel_idxs_pcd .unsqueeze (2 ) == neighborhood_idxs .unsqueeze (1 )).all (dim = - 1 )
97
+
98
+ # broadcast the gradients to the points
99
+ input_grad = torch .zeros_like (voxel_idxs_pcd , dtype = dists_grad .dtype )
100
+ input_grad += (mask .float () * dists_grad .unsqueeze (1 )).sum (dim = 2 ).unsqueeze (- 1 )
101
+
102
+ return input_grad , None , None
85
103
86
104
class Voxelizer (nn .Module ):
87
105
"""
0 commit comments