fix issues with deprecated library THC #383

Open
wants to merge 2 commits into base: master
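The changes below swap the deprecated THC helpers for their ATen equivalents: THCCeilDiv becomes at::ceil_div (declared in <ATen/ceil_div.h>), THCudaCheck becomes AT_CUDA_CHECK, AT_CHECK becomes TORCH_CHECK, and the #include <THC/THC.h> lines are dropped; the two CPU files only need their vision.h include path corrected. As a hedged illustration (not a file from this PR, and assuming a simple float-only kernel), the migrated launch pattern looks roughly like this:

// Sketch of the post-migration launch pattern; example_kernel and
// example_forward_cuda are hypothetical names, not code from this repository.
#include <ATen/ATen.h>
#include <ATen/ceil_div.h>          // at::ceil_div replaces THCCeilDiv
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <ATen/cuda/Exceptions.h>   // AT_CUDA_CHECK replaces THCudaCheck
#include <algorithm>

__global__ void example_kernel(int64_t n, const float* in, float* out) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

at::Tensor example_forward_cuda(const at::Tensor& input) {
  auto in = input.contiguous();
  auto output = at::empty_like(in);
  int64_t n = output.numel();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Grid size: ceil(n / 512), capped at 4096 blocks, as in the files below.
  dim3 grid(std::min(at::ceil_div(n, (int64_t)512), (int64_t)4096));
  dim3 block(512);

  if (n == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return output;
  }

  example_kernel<<<grid, block, 0, stream>>>(
      n, in.data_ptr<float>(), output.data_ptr<float>());
  AT_CUDA_CHECK(cudaGetLastError());
  return output;
}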
2 changes: 1 addition & 1 deletion fcos_core/csrc/cpu/ROIAlign_cpu.cpp
@@ -1,5 +1,5 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#include "cpu/vision.h"
+#include "vision.h"

 // implementation taken from Caffe2
 template <typename T>
2 changes: 1 addition & 1 deletion fcos_core/csrc/cpu/nms_cpu.cpp
@@ -1,5 +1,5 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#include "cpu/vision.h"
+#include "vision.h"


 template <typename scalar_t>
14 changes: 7 additions & 7 deletions fcos_core/csrc/cuda/ROIAlign_cuda.cu
@@ -1,8 +1,8 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/ceil_div.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -272,11 +272,11 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
   auto output_size = num_rois * pooled_height * pooled_width * channels;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)output_size, 512L), 4096L));
   dim3 block(512);

   if (output.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return output;
   }

@@ -294,7 +294,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
        rois.contiguous().data<scalar_t>(),
        output.data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return output;
 }

@@ -317,12 +317,12 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)grad.numel(), 512L), 4096L));
   dim3 block(512);

   // handle possibly empty gradients
   if (grad.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }

@@ -341,6 +341,6 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
        grad_input.data<scalar_t>(),
        rois.contiguous().data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
14 changes: 7 additions & 7 deletions fcos_core/csrc/cuda/ROIPool_cuda.cu
@@ -1,8 +1,8 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/ceil_div.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -126,11 +126,11 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)output_size, 512L), 4096L));
   dim3 block(512);

   if (output.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return std::make_tuple(output, argmax);
   }

@@ -148,7 +148,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
        output.data<scalar_t>(),
        argmax.data<int>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(output, argmax);
 }

@@ -173,12 +173,12 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)grad.numel(), 512L), 4096L));
   dim3 block(512);

   // handle possibly empty gradients
   if (grad.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return grad_input;
   }

@@ -197,6 +197,6 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
        grad_input.data<scalar_t>(),
        rois.contiguous().data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
 }
13 changes: 6 additions & 7 deletions fcos_core/csrc/cuda/SigmoidFocalLoss_cuda.cu
@@ -5,7 +5,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
 #include <THC/THCAtomics.cuh>
 #include <THC/THCDeviceUtils.cuh>

@@ -117,11 +116,11 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
   auto losses_size = num_samples * logits.size(1);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)losses_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)losses_size, 512L), 4096L));
   dim3 block(512);

   if (losses.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return losses;
   }

@@ -136,7 +135,7 @@ at::Tensor SigmoidFocalLoss_forward_cuda(
        num_samples,
        losses.data<scalar_t>());
   });
-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return losses;
 }

@@ -161,11 +160,11 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
   auto d_logits_size = num_samples * logits.size(1);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-  dim3 grid(std::min(THCCeilDiv((long)d_logits_size, 512L), 4096L));
+  dim3 grid(std::min(at::ceil_div((long)d_logits_size, 512L), 4096L));
   dim3 block(512);

   if (d_logits.numel() == 0) {
-    THCudaCheck(cudaGetLastError());
+    AT_CUDA_CHECK(cudaGetLastError());
     return d_logits;
   }

@@ -182,7 +181,7 @@ at::Tensor SigmoidFocalLoss_backward_cuda(
        d_logits.data<scalar_t>());
   });

-  THCudaCheck(cudaGetLastError());
+  AT_CUDA_CHECK(cudaGetLastError());
   return d_logits;
 }

43 changes: 21 additions & 22 deletions fcos_core/csrc/cuda/deform_conv_cuda.cu
@@ -4,7 +4,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
 #include <THC/THCDeviceUtils.cuh>

 #include <vector>
@@ -69,26 +68,26 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
                  int padW, int dilationH, int dilationW, int group,
                  int deformable_group)
 {
-  AT_CHECK(weight.ndimension() == 4,
+  TORCH_CHECK(weight.ndimension() == 4,
            "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
            "but got: %s",
            weight.ndimension());

-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

-  AT_CHECK(kW > 0 && kH > 0,
+  TORCH_CHECK(kW > 0 && kH > 0,
            "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
            kW);

-  AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+  TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
            "kernel size should be consistent with weight, ",
            "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
            kW, weight.size(2), weight.size(3));

-  AT_CHECK(dW > 0 && dH > 0,
+  TORCH_CHECK(dW > 0 && dH > 0,
            "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);

-  AT_CHECK(
+  TORCH_CHECK(
       dilationW > 0 && dilationH > 0,
       "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
       dilationH, dilationW);
@@ -104,7 +103,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
     dimw++;
   }

-  AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+  TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
            ndim);

   long nInputPlane = weight.size(1) * group;
@@ -116,7 +115,7 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
   long outputWidth =
       (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;

-  AT_CHECK(nInputPlane % deformable_group == 0,
+  TORCH_CHECK(nInputPlane % deformable_group == 0,
            "input channels must divide deformable group size");

   if (outputWidth < 1 || outputHeight < 1)
@@ -126,27 +125,27 @@ void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
         nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
         outputWidth);

-  AT_CHECK(input.size(1) == nInputPlane,
+  TORCH_CHECK(input.size(1) == nInputPlane,
            "invalid number of input planes, expected: %d, but got: %d",
            nInputPlane, input.size(1));

-  AT_CHECK((inputHeight >= kH && inputWidth >= kW),
+  TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
            "input image is smaller than kernel");

-  AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+  TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
            "invalid spatial size of offset, expected height: %d width: %d, but "
            "got height: %d width: %d",
            outputHeight, outputWidth, offset.size(2), offset.size(3));

-  AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+  TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
            "invalid number of channels of offset");

   if (gradOutput != NULL) {
-    AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
+    TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
              "invalid number of gradOutput planes, expected: %d, but got: %d",
              nOutputPlane, gradOutput->size(dimf));

-    AT_CHECK((gradOutput->size(dimh) == outputHeight &&
+    TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
               gradOutput->size(dimw) == outputWidth),
              "invalid size of gradOutput, expected height: %d width: %d , but "
              "got height: %d width: %d",
@@ -197,7 +196,7 @@ int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

-  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

   output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
                         outputHeight, outputWidth});
@@ -304,7 +303,7 @@ int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

-  AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
   gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
   columns = at::zeros(
       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
@@ -420,7 +419,7 @@ int deform_conv_backward_parameters_cuda(
   long outputHeight =
       (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;

-  AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");

   columns = at::zeros(
       {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
@@ -501,8 +500,8 @@ void modulated_deform_conv_cuda_forward(
     const int dilation_w, const int group, const int deformable_group,
     const bool with_bias)
 {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
@@ -583,8 +582,8 @@ void modulated_deform_conv_cuda_backward(
     int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
     const bool with_bias)
 {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
-  AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
7 changes: 3 additions & 4 deletions fcos_core/csrc/cuda/deform_pool_cuda.cu
@@ -8,7 +8,6 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>

-#include <THC/THC.h>
 #include <THC/THCDeviceUtils.cuh>

 #include <vector>
@@ -39,7 +38,7 @@ void deform_psroi_pooling_cuda_forward(
     const int output_dim, const int group_size, const int pooled_size,
     const int part_size, const int sample_per_part, const float trans_std)
 {
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
@@ -65,8 +64,8 @@ void deform_psroi_pooling_cuda_backward(
     const int group_size, const int pooled_size, const int part_size,
     const int sample_per_part, const float trans_std)
 {
-  AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
-  AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+  TORCH_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+  TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");

   const int batch = input.size(0);
   const int channels = input.size(1);
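In deform_conv_cuda.cu and deform_pool_cuda.cu the removed AT_CHECK macro is replaced by TORCH_CHECK. One caveat worth keeping in mind: TORCH_CHECK builds its error message by streaming the trailing arguments together rather than interpreting printf-style format strings, so messages ported verbatim will print their %d/%s placeholders literally. A hedged sketch of the idiomatic form (check_weight_shape is a hypothetical helper, not part of this PR):

// Hypothetical example of TORCH_CHECK usage after the AT_CHECK migration.
// TORCH_CHECK streams its trailing arguments into the error message,
// so values are passed as separate arguments instead of format specifiers.
#include <torch/extension.h>

void check_weight_shape(const at::Tensor& weight, int kH, int kW) {
  TORCH_CHECK(weight.ndimension() == 4,
              "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: ",
              weight.ndimension());
  TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
  TORCH_CHECK(kH > 0 && kW > 0,
              "kernel size should be greater than zero, but got kH: ", kH,
              " kW: ", kW);
}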