Skip to content

Commit 651bfbd

Browse files
authored
Encapsulate and standardize new_empty_tensor_op (#3089)
* Renaming C++ files & methods according to recommended naming conventions and aligning them with Python's API. * Create foreach cpp file a separate header file with "public" functions. * Adding all internal functions in anonymous namespaces. * Convert to const ref all possible parameters. * Removing unnecessary repeated includes.
1 parent 3c3c625 commit 651bfbd

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed
Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
11
#pragma once
22

3-
// All pure C++ headers for the C++ frontend.
4-
#include <torch/all.h>
3+
#include "new_empty_tensor_op.h"
4+
#include <torch/extension.h>
5+
6+
namespace {
57

68
class NewEmptyTensorOp : public torch::autograd::Function<NewEmptyTensorOp> {
79
public:
810
static torch::autograd::variable_list forward(
911
torch::autograd::AutogradContext* ctx,
10-
torch::autograd::Variable input,
11-
c10::List<int64_t> new_shape) {
12+
const torch::autograd::Variable& input,
13+
const c10::List<int64_t>& new_shape) {
1214
ctx->saved_data["shape"] = input.sizes();
1315
std::vector<int64_t> shape(new_shape.begin(), new_shape.end());
1416
return {input.new_empty(shape, at::TensorOptions())};
1517
}
1618

1719
static torch::autograd::variable_list backward(
1820
torch::autograd::AutogradContext* ctx,
19-
torch::autograd::variable_list grad_output) {
21+
const torch::autograd::variable_list& grad_output) {
2022
// Use data saved in forward
2123
auto shape = ctx->saved_data["shape"].toIntList();
2224
auto out = forward(ctx, grad_output[0], shape);
2325
return {out[0], at::Tensor()};
2426
}
2527
};
2628

27-
at::Tensor new_empty_tensor(const at::Tensor& input, c10::List<int64_t> shape) {
29+
} // namespace
30+
31+
at::Tensor new_empty_tensor(
32+
const at::Tensor& input,
33+
const c10::List<int64_t>& shape) {
2834
return NewEmptyTensorOp::apply(input, shape)[0];
2935
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <ATen/ATen.h>
4+
5+
at::Tensor new_empty_tensor(
6+
const at::Tensor& input,
7+
const c10::List<int64_t>& shape);

torchvision/csrc/vision.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#endif
1010

1111
#include "deform_conv2d.h"
12-
#include "empty_tensor_op.h"
12+
#include "new_empty_tensor_op.h"
1313
#include "nms.h"
1414
#include "ps_roi_align.h"
1515
#include "ps_roi_pool.h"

0 commit comments

Comments
 (0)