diff --git a/rasterize_points.cu b/rasterize_points.cu index e625c19e..2a66f631 100644 --- a/rasterize_points.cu +++ b/rasterize_points.cu @@ -75,11 +75,10 @@ RasterizeGaussiansCUDA( torch::Tensor radii = torch::full({P}, 0, means3D.options().dtype(torch::kInt32)); - torch::Device device(torch::kCUDA); - torch::TensorOptions options(torch::kByte); - torch::Tensor geomBuffer = torch::empty({0}, options.device(device)); - torch::Tensor binningBuffer = torch::empty({0}, options.device(device)); - torch::Tensor imgBuffer = torch::empty({0}, options.device(device)); + torch::TensorOptions byte_opts = means3D.options().dtype(torch::kByte); + torch::Tensor geomBuffer = torch::empty({0}, byte_opts); + torch::Tensor binningBuffer = torch::empty({0}, byte_opts); + torch::Tensor imgBuffer = torch::empty({0}, byte_opts); std::function geomFunc = resizeFunctional(geomBuffer); std::function binningFunc = resizeFunctional(binningBuffer); std::function imgFunc = resizeFunctional(imgBuffer);