@@ -16,26 +16,17 @@ std::vector<at::Tensor> three_nn(at::Tensor unknowns, at::Tensor knows)
1616 CHECK_IS_FLOAT (unknowns);
1717 CHECK_IS_FLOAT (knows);
1818
19- if (unknowns.type ().is_cuda ())
20- {
21- CHECK_CUDA (knows);
22- }
19+ CHECK_CUDA (knows);
20+ CHECK_CUDA (unknowns);
2321
2422 at::Tensor idx = torch::zeros ({unknowns.size (0 ), unknowns.size (1 ), 3 },
2523 at::device (unknowns.device ()).dtype (at::ScalarType::Int));
2624 at::Tensor dist2 = torch::zeros ({unknowns.size (0 ), unknowns.size (1 ), 3 },
2725 at::device (unknowns.device ()).dtype (at::ScalarType::Float));
2826
29- if (unknowns.type ().is_cuda ())
30- {
31- three_nn_kernel_wrapper (unknowns.size (0 ), unknowns.size (1 ), knows.size (1 ),
32- unknowns.DATA_PTR <float >(), knows.DATA_PTR <float >(),
33- dist2.DATA_PTR <float >(), idx.DATA_PTR <int >());
34- }
35- else
36- {
37- TORCH_CHECK (false , " CPU not supported" );
38- }
27+ three_nn_kernel_wrapper (unknowns.size (0 ), unknowns.size (1 ), knows.size (1 ),
28+ unknowns.DATA_PTR <float >(), knows.DATA_PTR <float >(),
29+ dist2.DATA_PTR <float >(), idx.DATA_PTR <int >());
3930
4031 return {dist2, idx};
4132}
@@ -49,25 +40,15 @@ at::Tensor three_interpolate(at::Tensor points, at::Tensor idx, at::Tensor weigh
4940 CHECK_IS_INT (idx);
5041 CHECK_IS_FLOAT (weight);
5142
52- if (points.type ().is_cuda ())
53- {
54- CHECK_CUDA (idx);
55- CHECK_CUDA (weight);
56- }
43+ CHECK_CUDA (idx);
44+ CHECK_CUDA (weight);
5745
5846 at::Tensor output = torch::zeros ({points.size (0 ), points.size (1 ), idx.size (1 )},
5947 at::device (points.device ()).dtype (at::ScalarType::Float));
6048
61- if (points.type ().is_cuda ())
62- {
63- three_interpolate_kernel_wrapper (points.size (0 ), points.size (1 ), points.size (2 ),
64- idx.size (1 ), points.DATA_PTR <float >(), idx.DATA_PTR <int >(),
65- weight.DATA_PTR <float >(), output.DATA_PTR <float >());
66- }
67- else
68- {
69- TORCH_CHECK (false , " CPU not supported" );
70- }
49+ three_interpolate_kernel_wrapper (points.size (0 ), points.size (1 ), points.size (2 ), idx.size (1 ),
50+ points.DATA_PTR <float >(), idx.DATA_PTR <int >(),
51+ weight.DATA_PTR <float >(), output.DATA_PTR <float >());
7152
7253 return output;
7354}
@@ -80,26 +61,16 @@ at::Tensor three_interpolate_grad(at::Tensor grad_out, at::Tensor idx, at::Tenso
8061 CHECK_IS_FLOAT (grad_out);
8162 CHECK_IS_INT (idx);
8263 CHECK_IS_FLOAT (weight);
83-
84- if (grad_out.type ().is_cuda ())
85- {
86- CHECK_CUDA (idx);
87- CHECK_CUDA (weight);
88- }
64+ CHECK_CUDA (idx);
65+ CHECK_CUDA (weight);
66+ CHECK_CUDA (grad_out);
8967
9068 at::Tensor output = torch::zeros ({grad_out.size (0 ), grad_out.size (1 ), m},
9169 at::device (grad_out.device ()).dtype (at::ScalarType::Float));
9270
93- if (grad_out.type ().is_cuda ())
94- {
95- three_interpolate_grad_kernel_wrapper (grad_out.size (0 ), grad_out.size (1 ), grad_out.size (2 ),
96- m, grad_out.DATA_PTR <float >(), idx.DATA_PTR <int >(),
97- weight.DATA_PTR <float >(), output.DATA_PTR <float >());
98- }
99- else
100- {
101- TORCH_CHECK (false , " CPU not supported" );
102- }
71+ three_interpolate_grad_kernel_wrapper (grad_out.size (0 ), grad_out.size (1 ), grad_out.size (2 ), m,
72+ grad_out.DATA_PTR <float >(), idx.DATA_PTR <int >(),
73+ weight.DATA_PTR <float >(), output.DATA_PTR <float >());
10374
10475 return output;
10576}
0 commit comments