Skip to content

Adding operator methods in vision::ops namespace. #3096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/tracing/frcnn/test_frcnn_tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#ifdef _WIN32
// Windows only
// This is necessary until operators are automatically registered on include
static auto _nms = &nms_cpu;
static auto _nms = &vision::ops::nms_cpu;
#endif

int main() {
Expand Down
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/deform_conv2d_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@

#include "deform_conv2d_kernel.h"

namespace vision {
namespace ops {

namespace {

const int kMaxParallelImgs = 32;
Expand Down Expand Up @@ -1137,3 +1140,6 @@ deform_conv2d_backward_cpu(
return std::make_tuple(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/deform_conv2d_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API at::Tensor deform_conv2d_forward_cpu(
const at::Tensor& input,
const at::Tensor& weight,
Expand Down Expand Up @@ -37,3 +40,6 @@ VISION_API std::
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/nms_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "nms_kernel.h"

namespace vision {
namespace ops {

namespace {

template <typename scalar_t>
Expand Down Expand Up @@ -103,3 +106,6 @@ at::Tensor nms_cpu(
});
return result;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/nms_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/ps_roi_align_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "ps_roi_align_kernel.h"

namespace vision {
namespace ops {

namespace {

template <typename T>
Expand Down Expand Up @@ -416,3 +419,6 @@ at::Tensor ps_roi_align_backward_cpu(
});
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/ps_roi_align_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -23,3 +26,6 @@ VISION_API at::Tensor ps_roi_align_backward_cpu(
int64_t channels,
int64_t height,
int64_t width);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/ps_roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "ps_roi_pool_kernel.h"

namespace vision {
namespace ops {

namespace {

template <class T>
Expand Down Expand Up @@ -255,3 +258,6 @@ at::Tensor ps_roi_pool_backward_cpu(
});
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/ps_roi_pool_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -21,3 +24,6 @@ VISION_API at::Tensor ps_roi_pool_backward_cpu(
int64_t channels,
int64_t height,
int64_t width);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/roi_align_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include "roi_align_kernel.h"

namespace vision {
namespace ops {

namespace {

// implementation taken from Caffe2
Expand Down Expand Up @@ -494,3 +497,6 @@ at::Tensor roi_align_backward_cpu(
});
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/roi_align_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API at::Tensor roi_align_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -24,3 +27,6 @@ VISION_API at::Tensor roi_align_backward_cpu(
int64_t width,
int64_t sampling_ratio,
bool aligned);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/roi_pool_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include "roi_pool_kernel.h"

namespace vision {
namespace ops {

namespace {

template <class T>
Expand Down Expand Up @@ -231,3 +234,6 @@ at::Tensor roi_pool_backward_cpu(
});
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cpu/roi_pool_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API std::tuple<at::Tensor, at::Tensor> roi_pool_forward_cpu(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -21,3 +24,6 @@ VISION_API at::Tensor roi_pool_backward_cpu(
int64_t channels,
int64_t height,
int64_t width);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/cuda_helpers.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#pragma once

namespace vision {
namespace ops {

#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
i += (blockDim.x * gridDim.x))
Expand All @@ -8,3 +11,6 @@ template <typename integer>
constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/deform_conv2d_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@
#include "cuda_helpers.h"
#include "deform_conv2d_kernel.h"

namespace vision {
namespace ops {

namespace {

const int kMaxParallelImgs = 32;
Expand Down Expand Up @@ -1183,3 +1186,6 @@ deform_conv2d_backward_cuda(
return std::make_tuple(
grad_input, grad_weight, grad_offset, grad_mask, grad_bias);
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/deform_conv2d_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API at::Tensor deform_conv2d_forward_cuda(
const at::Tensor& input,
const at::Tensor& weight,
Expand Down Expand Up @@ -37,3 +40,6 @@ VISION_API std::
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/nms_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#include "cuda_helpers.h"
#include "nms_kernel.h"

namespace vision {
namespace ops {

namespace {

int const threadsPerBlock = sizeof(unsigned long long) * 8;
Expand Down Expand Up @@ -162,3 +165,6 @@ at::Tensor nms_cuda(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/nms_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
double iou_threshold);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/ps_roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "cuda_helpers.h"
#include "ps_roi_align_kernel.h"

namespace vision {
namespace ops {

namespace {

template <typename T>
Expand Down Expand Up @@ -434,3 +437,6 @@ at::Tensor ps_roi_align_backward_cuda(
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/ps_roi_align_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -23,3 +26,6 @@ VISION_API at::Tensor ps_roi_align_backward_cuda(
int64_t channels,
int64_t height,
int64_t width);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/ps_roi_pool_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "cuda_helpers.h"
#include "ps_roi_pool_kernel.h"

namespace vision {
namespace ops {

namespace {

template <typename T>
Expand Down Expand Up @@ -270,3 +273,6 @@ at::Tensor ps_roi_pool_backward_cuda(
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/ps_roi_pool_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_pool_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -21,3 +24,6 @@ VISION_API at::Tensor ps_roi_pool_backward_cuda(
int64_t channels,
int64_t height,
int64_t width);

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/roi_align_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "cuda_helpers.h"
#include "roi_align_kernel.h"

namespace vision {
namespace ops {

namespace {

template <typename T>
Expand Down Expand Up @@ -443,3 +446,6 @@ at::Tensor roi_align_backward_cuda(
AT_CUDA_CHECK(cudaGetLastError());
return grad_input;
}

} // namespace ops
} // namespace vision
6 changes: 6 additions & 0 deletions torchvision/csrc/cuda/roi_align_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
#include <ATen/ATen.h>
#include "../macros.h"

namespace vision {
namespace ops {

VISION_API at::Tensor roi_align_forward_cuda(
const at::Tensor& input,
const at::Tensor& rois,
Expand All @@ -24,3 +27,6 @@ VISION_API at::Tensor roi_align_backward_cuda(
int64_t width,
int64_t sampling_ratio,
bool aligned);

} // namespace ops
} // namespace vision
Loading