diff --git a/torchvision/csrc/empty_tensor_op.h b/torchvision/csrc/new_empty_tensor_op.cpp similarity index 67% rename from torchvision/csrc/empty_tensor_op.h rename to torchvision/csrc/new_empty_tensor_op.cpp index 99448109762..e4f31600c54 100644 --- a/torchvision/csrc/empty_tensor_op.h +++ b/torchvision/csrc/new_empty_tensor_op.cpp @@ -1,14 +1,16 @@ #pragma once -// All pure C++ headers for the C++ frontend. -#include +#include "new_empty_tensor_op.h" +#include + +namespace { class NewEmptyTensorOp : public torch::autograd::Function { public: static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - torch::autograd::Variable input, - c10::List new_shape) { + const torch::autograd::Variable& input, + const c10::List& new_shape) { ctx->saved_data["shape"] = input.sizes(); std::vector shape(new_shape.begin(), new_shape.end()); return {input.new_empty(shape, at::TensorOptions())}; @@ -16,7 +18,7 @@ class NewEmptyTensorOp : public torch::autograd::Function { static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_output) { + const torch::autograd::variable_list& grad_output) { // Use data saved in forward auto shape = ctx->saved_data["shape"].toIntList(); auto out = forward(ctx, grad_output[0], shape); @@ -24,6 +26,10 @@ class NewEmptyTensorOp : public torch::autograd::Function { } }; -at::Tensor new_empty_tensor(const at::Tensor& input, c10::List shape) { +} // namespace + +at::Tensor new_empty_tensor( + const at::Tensor& input, + const c10::List& shape) { return NewEmptyTensorOp::apply(input, shape)[0]; } diff --git a/torchvision/csrc/new_empty_tensor_op.h b/torchvision/csrc/new_empty_tensor_op.h new file mode 100644 index 00000000000..75f4cd5a7fe --- /dev/null +++ b/torchvision/csrc/new_empty_tensor_op.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +at::Tensor new_empty_tensor( + const at::Tensor& input, + const c10::List& shape); diff --git a/torchvision/csrc/vision.cpp b/torchvision/csrc/vision.cpp index d764ec9334b..fb0bf014912 100644 --- a/torchvision/csrc/vision.cpp +++ b/torchvision/csrc/vision.cpp @@ -9,7 +9,7 @@ #endif #include "deform_conv2d.h" -#include "empty_tensor_op.h" +#include "new_empty_tensor_op.h" #include "nms.h" #include "ps_roi_align.h" #include "ps_roi_pool.h"