diff --git a/fcos_core/csrc/cpu/ROIAlign_cpu.cpp b/fcos_core/csrc/cpu/ROIAlign_cpu.cpp
index d35aedf2..822c0bcc 100644
--- a/fcos_core/csrc/cpu/ROIAlign_cpu.cpp
+++ b/fcos_core/csrc/cpu/ROIAlign_cpu.cpp
@@ -1,5 +1,5 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#include "cpu/vision.h"
+#include "vision.h"

 // implementation taken from Caffe2
 template <typename T>
diff --git a/fcos_core/csrc/cpu/nms_cpu.cpp b/fcos_core/csrc/cpu/nms_cpu.cpp
index 1153dea0..5b9c09d4 100644
--- a/fcos_core/csrc/cpu/nms_cpu.cpp
+++ b/fcos_core/csrc/cpu/nms_cpu.cpp
@@ -1,5 +1,5 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#include "cpu/vision.h"
+#include "vision.h"


 template <typename scalar_t>
diff --git a/fcos_core/csrc/cuda/ROIAlign_cuda.cu b/fcos_core/csrc/cuda/ROIAlign_cuda.cu
index 1142fb37..3551d8db 100644
--- a/fcos_core/csrc/cuda/ROIAlign_cuda.cu
+++ b/fcos_core/csrc/cuda/ROIAlign_cuda.cu
@@ -1,8 +1,8 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/ceil_div.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -272,11 +272,11 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)output_size, 512L), 4096L));
   dim3 block(512);

   if (output.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }

@@ -294,7 +294,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
         rois.contiguous().data<scalar_t>(),
         output.data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }

@@ -317,12 +317,12 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)grad.numel(), 512L), 4096L));
   dim3 block(512);

   // handle possibly empty gradients
   if (grad.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }

@@ -341,6 +341,6 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
          grad_input.data<scalar_t>(),
          rois.contiguous().data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
diff --git a/fcos_core/csrc/cuda/ROIPool_cuda.cu b/fcos_core/csrc/cuda/ROIPool_cuda.cu
index 8f072ffc..6cbe61b5 100644
--- a/fcos_core/csrc/cuda/ROIPool_cuda.cu
+++ b/fcos_core/csrc/cuda/ROIPool_cuda.cu
@@ -1,8 +1,8 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/ceil_div.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -126,11 +126,11 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)output_size, 512L), 4096L));
   dim3 block(512);

   if (output.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return std::make_tuple(output, argmax);
   }

@@ -148,7 +148,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
         output.data<scalar_t>(),
         argmax.data<int>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(output, argmax);
 }

@@ -173,12 +173,12 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)grad.numel(), 512L), 4096L));
   dim3 block(512);

   // handle possibly empty gradients
   if (grad.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }

@@ -197,6 +197,6 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
          grad_input.data<scalar_t>(),
          rois.contiguous().data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
diff --git a/fcos_core/csrc/cuda/SigmoidFocalLoss_cuda.cu b/fcos_core/csrc/cuda/SigmoidFocalLoss_cuda.cu
index 0ac6fb5e..ce162510 100644
--- a/fcos_core/csrc/cuda/SigmoidFocalLoss_cuda.cu
+++ b/fcos_core/csrc/cuda/SigmoidFocalLoss_cuda.cu
@@ -5,7 +5,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -117,11 +116,11 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
   auto losses_size = num_samples * logits.size(1);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)losses_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)losses_size, 512L), 4096L));
   dim3 block(512);

   if (losses.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return losses;
   }

@@ -136,7 +135,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
         num_samples,
         losses.data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return losses;
 }

@@ -161,11 +160,11 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
   auto d_logits_size = num_samples * logits.size(1);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)d_logits_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)d_logits_size, 512L), 4096L));
   dim3 block(512);

   if (d_logits.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return d_logits;
   }

@@ -182,7 +181,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
         d_logits.data<scalar_t>());
   });

-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return d_logits;
 }
diff --git a/fcos_core/csrc/cuda/deform_conv_cuda.cu b/fcos_core/csrc/cuda/deform_conv_cuda.cu
index 74f7d339..31213075 100644
--- a/fcos_core/csrc/cuda/deform_conv_cuda.cu
+++ b/fcos_core/csrc/cuda/deform_conv_cuda.cu
@@ -4,7 +4,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -69,26 +68,26 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                  int padW, int dilationH, int dilationW, int group,
                  int deformable_group) {
-  AT_CHECK(weight.ndimension() == 4,
+  TORCH_CHECK(weight.ndimension() == 4,
            "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
            "but got: %s",
            weight.ndimension());

-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

-  AT_CHECK(kW > 0 && kH > 0,
+  TORCH_CHECK(kW > 0 && kH > 0,
            "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
            kW);

-  AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
            "kernel size should be consistent with weight, ",
            "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
            kW, weight.size(2), weight.size(3));

-  AT_CHECK(dW > 0 && dH > 0,
+  TORCH_CHECK(dW > 0 && dH > 0,
            "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);

-  AT_CHECK(
+  TORCH_CHECK(
       dilationW > 0 && dilationH > 0,
       "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
       dilationH, dilationW);
@@ -104,7 +103,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
     dimw++;
   }

-  AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
            ndim);

   long nInputPlane = weight.size(1) * group;
@@ -116,7 +115,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
   long outputWidth =
       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

-  AT_CHECK(nInputPlane % deformable_group == 0,
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
            "input channels must divide deformable group size");

   if (outputWidth < 1 || outputHeight < 1)
@@ -126,27 +125,27 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
         nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
         outputWidth);

-  AT_CHECK(input.size(1) == nInputPlane,
+  TORCH_CHECK(input.size(1) == nInputPlane,
            "invalid number of input planes, expected: %d, but got: %d",
            nInputPlane, input.size(1));

-  AT_CHECK((inputHeight >= kH && inputWidth >= kW),
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
            "input image is smaller than kernel");

-  AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
            "invalid spatial size of offset, expected height: %d width: %d, but "
            "got height: %d width: %d",
            outputHeight, outputWidth, offset.size(2), offset.size(3));

-  AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
            "invalid number of channels of offset");

   if (gradOutput != NULL) {
-    AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
              "invalid number of gradOutput planes, expected: %d, but got: %d",
              nOutputPlane, gradOutput->size(dimf));

-    AT_CHECK((gradOutput->size(dimh) == outputHeight &&
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
               gradOutput->size(dimw) == outputWidth),
              "invalid size of gradOutput, expected height: %d width: %d , but "
              "got height: %d width: %d",
@@ -197,7 +196,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

-  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

   output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                         outputHeight, outputWidth});
@@ -304,7 +303,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

-  AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");

   gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
   columns = at::zeros(
       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
@@ -420,7 +419,7 @@ int deform_conv_backward_parameters_cuda(
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

-  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

   columns = at::zeros(
       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
@@ -501,8 +500,8 @@ void modulated_deform_conv_cuda_forward(
     const int dilation_w, const int group, const int deformable_group,
     const bool with_bias) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
@@ -583,8 +582,8 @@ void modulated_deform_conv_cuda_backward(
     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
     const bool with_bias) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
diff --git a/fcos_core/csrc/cuda/deform_pool_cuda.cu b/fcos_core/csrc/cuda/deform_pool_cuda.cu
index 71f305af..a513d91d 100644
--- a/fcos_core/csrc/cuda/deform_pool_cuda.cu
+++ b/fcos_core/csrc/cuda/deform_pool_cuda.cu
@@ -8,7 +8,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -39,7 +38,7 @@ void deform_psroi_pooling_cuda_forward(
     const int output_dim, const int group_size, const int pooled_size,
     const int part_size, const int sample_per_part, const float trans_std) {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
@@ -65,8 +64,8 @@ void deform_psroi_pooling_cuda_backward(
     const int group_size, const int pooled_size, const int part_size,
     const int sample_per_part, const float trans_std) {
-  AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
diff --git a/fcos_core/csrc/cuda/ml_nms.cu b/fcos_core/csrc/cuda/ml_nms.cu
index 222c24f9..f6cb8f9f 100644
--- a/fcos_core/csrc/cuda/ml_nms.cu
+++ b/fcos_core/csrc/cuda/ml_nms.cu
@@ -1,8 +1,9 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/ceil_div.h>
+#include <c10/cuda/CUDACachingAllocator.h>

-#include <THC/THC.h>
 #include <THC/THCDeviceUtils.cuh>

 #include <vector>
@@ -66,7 +67,7 @@ __global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
         t |= 1ULL << i;
       }
     }
-    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+    const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
     dev_mask[cur_box_idx * col_blocks + col_start] = t;
   }
 }
@@ -81,20 +82,19 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {

   int boxes_num = boxes.size(0);

-  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+  const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

-  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

   unsigned long long* mask_dev = NULL;
   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
-  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+  mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

-  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
-              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
+              at::ceil_div(boxes_num, threadsPerBlock));
   dim3 threads(threadsPerBlock);
   ml_nms_kernel<<<blocks, threads>>>(boxes_num,
                                   nms_overlap_thresh,
@@ -102,7 +102,7 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
                                   mask_dev);

   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpy(&mask_host[0],
+  AT_CUDA_CHECK(cudaMemcpy(&mask_host[0],
                         mask_dev,
                         sizeof(unsigned long long) * boxes_num * col_blocks,
                         cudaMemcpyDeviceToHost));
@@ -127,7 +127,7 @@ at::Tensor ml_nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
     }
   }

-  THCudaFree(state, mask_dev);
+  c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
   // TODO improve this part
   return std::get<0>(order_t.index({
                        keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
diff --git a/fcos_core/csrc/cuda/nms.cu b/fcos_core/csrc/cuda/nms.cu
index 833d8523..19c4587d 100644
--- a/fcos_core/csrc/cuda/nms.cu
+++ b/fcos_core/csrc/cuda/nms.cu
@@ -1,8 +1,9 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/ceil_div.h>
+#include <c10/cuda/CUDACachingAllocator.h>

-#include <THC/THC.h>
 #include <THC/THCDeviceUtils.cuh>

 #include <vector>
@@ -61,7 +62,7 @@ __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
         t |= 1ULL << i;
       }
     }
-    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
+    const int col_blocks = at::ceil_div(n_boxes, threadsPerBlock);
     dev_mask[cur_box_idx * col_blocks + col_start] = t;
   }
 }
@@ -76,20 +77,19 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {

   int boxes_num = boxes.size(0);

-  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
+  const int col_blocks = at::ceil_div(boxes_num, threadsPerBlock);

   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();

-  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState

   unsigned long long* mask_dev = NULL;
   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
-  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
+  mask_dev = (unsigned long long*) c10::cuda::CUDACachingAllocator::raw_alloc(boxes_num * col_blocks * sizeof(unsigned long long));

-  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
-              THCCeilDiv(boxes_num, threadsPerBlock));
+  dim3 blocks(at::ceil_div(boxes_num, threadsPerBlock),
+              at::ceil_div(boxes_num, threadsPerBlock));
   dim3 threads(threadsPerBlock);
   nms_kernel<<<blocks, threads>>>(boxes_num,
                                   nms_overlap_thresh,
@@ -97,7 +97,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
                                   mask_dev);

   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpy(&mask_host[0],
+  AT_CUDA_CHECK(cudaMemcpy(&mask_host[0],
                         mask_dev,
                         sizeof(unsigned long long) * boxes_num * col_blocks,
                         cudaMemcpyDeviceToHost));
@@ -122,7 +122,7 @@ at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
     }
   }

-  THCudaFree(state, mask_dev);
+  c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
   // TODO improve this part
   return std::get<0>(order_t.index({
                        keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
diff --git a/fcos_core/utils/imports.py b/fcos_core/utils/imports.py
index 53e27e2b..cc7f4b0e 100644
--- a/fcos_core/utils/imports.py
+++ b/fcos_core/utils/imports.py
@@ -1,7 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
+import sys

-if torch._six.PY3:
+if sys.version_info.major == 3:
     import importlib
     import importlib.util
     import sys
diff --git a/fcos_core/utils/model_zoo.py b/fcos_core/utils/model_zoo.py
index 6c40686b..7a357c5a 100644
--- a/fcos_core/utils/model_zoo.py
+++ b/fcos_core/utils/model_zoo.py
@@ -3,11 +3,11 @@
 import sys

 try:
-    from torch.utils.model_zoo import _download_url_to_file
+    from torch.utils.model_zoo import download_url_to_file
     from torch.utils.model_zoo import urlparse
     from torch.utils.model_zoo import HASH_REGEX
 except:
-    from torch.hub import _download_url_to_file
+    from torch.hub import download_url_to_file
     from torch.hub import urlparse
     from torch.hub import HASH_REGEX

@@ -59,6 +59,6 @@ def cache_url(url, model_dir=None, progress=True):
     # if the hash_prefix is less than 6 characters
     if len(hash_prefix) < 6:
         hash_prefix = None
-    _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+    download_url_to_file(url, cached_file, hash_prefix, progress=progress)
     synchronize()
     return cached_file
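
Every CUDA file in this patch makes the same two substitutions around its kernel launches: at::ceil_div (from ATen/ceil_div.h) replaces THCCeilDiv when computing the grid size, and AT_CUDA_CHECK replaces THCudaCheck when surfacing launch errors. The fragment below is a minimal, self-contained sketch of that pattern on a made-up kernel; only the 512-thread block and the 4096-block grid cap mirror the patch, while the kernel name and its arguments are illustrative and not part of the repo.

    #include <algorithm>

    #include <ATen/ceil_div.h>          // at::ceil_div, replaces THCCeilDiv
    #include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
    #include <ATen/cuda/Exceptions.h>   // AT_CUDA_CHECK, replaces THCudaCheck

    // Illustrative kernel: writes 1.0f into every element of a device buffer.
    __global__ void fill_ones_kernel(float* out, long n) {
      long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = 1.0f;
    }

    void fill_ones(float* out, long n) {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      // Same launch-configuration arithmetic as the patched ROIAlign/ROIPool/
      // SigmoidFocalLoss code: ceil(n / 512) blocks, capped at 4096.
      dim3 grid(std::min(at::ceil_div(n, 512L), 4096L));
      dim3 block(512);
      fill_ones_kernel<<<grid, block, 0, stream>>>(out, n);
      AT_CUDA_CHECK(cudaGetLastError());
    }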
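nms.cu and ml_nms.cu additionally drop the THCState handle: the temporary overlap-mask buffer now comes from the CUDA caching allocator via raw_alloc and is returned via raw_delete. Below is a sketch of that allocate/use/free sequence under the same assumptions as the patch; the helper name, buffer size, and the cudaMemset standing in for the real kernel work are illustrative only.

    #include <cstddef>

    #include <cuda_runtime.h>
    #include <ATen/cuda/Exceptions.h>           // AT_CUDA_CHECK
    #include <c10/cuda/CUDACachingAllocator.h>  // raw_alloc / raw_delete

    // Allocate n 64-bit mask words on the current device, use them, free them.
    void mask_buffer_roundtrip(size_t n) {
      auto* mask_dev = static_cast<unsigned long long*>(
          c10::cuda::CUDACachingAllocator::raw_alloc(n * sizeof(unsigned long long)));

      // ... in the real code, nms_kernel / ml_nms_kernel write into mask_dev here ...
      AT_CUDA_CHECK(cudaMemset(mask_dev, 0, n * sizeof(unsigned long long)));

      c10::cuda::CUDACachingAllocator::raw_delete(mask_dev);
    }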
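On the Python side, model_zoo.py keeps its try/except import but switches from the private _download_url_to_file helper to the public download_url_to_file, which current PyTorch exposes from torch.hub. A hedged sketch of the same fallback is below; the commented call shows the torch.hub signature, and nothing is downloaded.

    try:
        # some layouts re-export the helper from torch.utils.model_zoo
        from torch.utils.model_zoo import download_url_to_file
    except ImportError:
        # current PyTorch: the public function lives in torch.hub
        from torch.hub import download_url_to_file

    # usage: download_url_to_file(url, dst, hash_prefix=None, progress=True)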