Commit 9aae82b

Ailing Zhang authored and facebook-github-bot committed
Improvements for current AD (pytorch#17187)
Summary: This PR removes a few `self` sizes that were passed from the forward pass to the backward pass in cases where `self` is already required by the backward pass. This could be the cause of the potential slowdown in pytorch#16689. I will attach a few perf numbers I got in a comment (they are still a bit volatile across runs, though).

Pull Request resolved: pytorch#17187

Differential Revision: D14179512

Pulled By: ailzhang

fbshipit-source-id: 5f3b1f6f26a3fef6dec15623b940380cc13656fa
1 parent e422b27 commit 9aae82b

16 files changed: +837 additions, −478 deletions
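
The gist of the change, as a rough Python sketch (a hypothetical illustration only; the commit itself works at the level of the JIT's TorchScript autodiff formulas, not Python autograd): when a tensor is already saved for the backward pass, its sizes can be recomputed there instead of being emitted and captured as extra forward outputs.

```python
import torch

# Hypothetical sketch only; the names MulBefore/MulAfter are illustrative and
# not part of this commit.
class MulBefore(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # "Before": the size of `a` is stashed as an extra saved value,
        # even though `a` and `b` themselves are already saved.
        ctx.save_for_backward(a, b)
        ctx.a_size = a.size()
        return a * b

    @staticmethod
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return (grad * b).sum_to_size(ctx.a_size), (grad * a).sum_to_size(b.size())

class MulAfter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        # "After": only the tensors cross the forward/backward boundary;
        # sizes are recomputed from them inside the backward.
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return (grad * b).sum_to_size(a.size()), (grad * a).sum_to_size(b.size())
```

In the expect-file diffs below, this shows up as `aten::size` calls moving out of the forward graph (where their `int[]` results had to be captured) and into the backward graph, which shrinks the list of captured outputs.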

aten/src/ATen/native/ReduceOps.cpp

Lines changed: 0 additions & 61 deletions
@@ -335,42 +335,6 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntArrayRef dim, ScalarType
   return at::native::sum_out(result, self, dim, false, dtype);
 }
 
-int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) {
-  int64_t size = 1;
-  if (sizes.size() == 0) {
-    return 1;
-  }
-  for (auto d : dim) {
-    d = at::maybe_wrap_dim(d, sizes.size());
-    size *= sizes[d];
-  }
-  return size;
-}
-
-Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) {
-  auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
-  Tensor res = t;
-  for (size_t i = 0; i < n_dims; i++){
-    if (dims_to_unsqueeze[i]) {
-      res = res.unsqueeze(i);
-    }
-  }
-  return res;
-}
-
-Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) {
-  if (!keepdim && sizes.size() > 0) {
-    if (dims.size()==1) {
-      return grad.unsqueeze(dims[0]).expand(sizes);
-    } else {
-      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
-      return res.expand(sizes);
-    }
-  } else {
-    return grad.expand(sizes);
-  }
-}
-
 Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
   return at::native::prod_out(
       result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
@@ -452,16 +416,6 @@ Tensor logsumexp(const Tensor &self, IntArrayRef dims, bool keepdim) {
   return at::native::logsumexp_out(result, self, dims, keepdim);
 }
 
-Tensor logsumexp_backward(const Tensor& grad, const Tensor & self, const Tensor& res, IntArrayRef dim, bool keepdim) {
-  Tensor grad_input = grad;
-  Tensor fwd_res = res;
-  if (!keepdim && self.dim() != 0) {
-    grad_input = unsqueeze_multiple(grad, dim, self.sizes().size());
-    fwd_res = unsqueeze_multiple(res, dim, self.sizes().size());
-  }
-  return grad_input * (self - fwd_res).exp();
-}
-
 static Tensor& norm_out(Tensor &result, const Tensor &self, optional<Scalar> opt_p,
                         IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
   auto p = opt_p.value_or(2.0);
@@ -674,21 +628,6 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias
   return std_var_out(result, self, dim, unbiased, keepdim, false);
 }
 
-Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) {
-  return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean());
-}
-
-Tensor var_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) {
-  if (self.dim() == 0) {
-    return at::var_backward(grad, self, unbiased);
-  }
-  Tensor unsqueezed_grad = grad;
-  if (!keepdim && self.dim() > 1) {
-    unsqueezed_grad = unsqueeze_multiple(grad, dim, self.sizes().size());
-  }
-  return (2.0 / (at::_safe_size(self.sizes(), dim) - unbiased)) * unsqueezed_grad * (self - self.mean(dim, true));
-}
-
 Tensor std(const Tensor& self, bool unbiased) {
   AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA,
            "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend()));

aten/src/ATen/native/TensorCompare.cpp

Lines changed: 0 additions & 8 deletions
@@ -34,14 +34,6 @@ namespace at { namespace native {
 DEFINE_DISPATCH(max_kernel);
 DEFINE_DISPATCH(min_kernel);
 
-Tensor index_select_backward(const Tensor& grad, int64_t dim, const Tensor& indices, IntArrayRef sizes, bool keepdim) {
-  Tensor res = at::zeros(sizes, grad.options());
-  if (!keepdim && sizes.size() > 0) {
-    return res.scatter_(dim, indices.unsqueeze(dim), grad.unsqueeze(dim));
-  }
-  return res.scatter_(dim, indices, grad);
-}
-
 bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) {
   return at::isclose(self, other, rtol, atol, equal_nan).all().item<uint8_t>();
 }

aten/src/ATen/native/TensorShape.cpp

Lines changed: 0 additions & 44 deletions
@@ -384,16 +384,6 @@ Tensor permute(const Tensor& self, IntArrayRef dims) {
   return self.as_strided(newSizes, newStrides);
 }
 
-Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) {
-  // invert the permutation
-  auto ndims = fwd_dims.size();
-  std::vector<int64_t> dims(ndims);
-  for (size_t i = 0; i < ndims; i++) {
-    dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i;
-  }
-  return grad.permute(dims);
-}
-
 Tensor repeat(const Tensor& self, IntArrayRef repeats) {
   AT_CHECK(repeats.size() >= (size_t)self.dim(),
            "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
@@ -461,12 +451,6 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   return self.as_strided(sizes, strides, storage_offset);
 }
 
-Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
-  auto grad_input = at::zeros(input_sizes, grad.options());
-  grad_input.select(dim, index).copy_(grad);
-  return grad_input;
-}
-
 Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
   int64_t ndim = self.dim();
   if (ndim == 0) {
@@ -500,12 +484,6 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_
   return self.as_strided(sizes, strides, storage_offset);
 }
 
-Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
-  auto grad_input = at::zeros(input_sizes, grad.options());
-  grad_input.slice(dim, start, end, step).copy_(grad);
-  return grad_input;
-}
-
 std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) {
   AT_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
   AT_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size);
@@ -712,28 +690,6 @@ Tensor squeeze(const Tensor& self) {
   return self.as_strided(std::get<0>(g), std::get<1>(g));
 }
 
-Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) {
-  auto result = self;
-
-  int64_t nDims = sizes.size();
-  for (int64_t dim = 0; dim < nDims; dim++) {
-    if (sizes[dim] == 1) {
-      result = result.unsqueeze(dim);
-    }
-  }
-  return result;
-}
-
-Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
-  // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided
-  // unsqueezing in the backward.
-  if (sizes.size() > 0 && sizes[dim] == 1) {
-    return self.unsqueeze(dim);
-  }
-  return self;
-}
-
 Tensor squeeze(const Tensor& self, int64_t dim) {
   int64_t dims = self.dim();
   dim = maybe_wrap_dim(dim, dims);
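
The removed `permute_backwards` helper applies the inverse permutation to the incoming gradient; a rough Python equivalent (for reference only, not part of this commit):

```python
import torch

# Rough Python rendering of the removed permute_backwards helper.
def permute_backwards(grad, fwd_dims):
    ndims = len(fwd_dims)
    inverse = [0] * ndims
    for i, d in enumerate(fwd_dims):
        inverse[d % ndims] = i  # invert the permutation (negative dims wrap)
    return grad.permute(inverse)
```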

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 28 deletions
@@ -59,10 +59,6 @@
   dispatch:
     CUDA: _cudnn_init_dropout_state
 
-- func: index_select_backward(Tensor grad, int64_t dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor
-
-- func: select_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t index) -> Tensor
-
 - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function
@@ -1335,9 +1331,6 @@
 - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: logsumexp_backward(Tensor grad, Tensor self, Tensor res, int[1] dim, bool keepdim) -> Tensor
-  matches_jit_signature: True
-
 - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor
   matches_jit_signature: True
 
@@ -1410,9 +1403,6 @@
 - func: mean(Tensor self, int[1] dim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: sum_backward(Tensor grad, int[] sizes, int[] dims, bool keepdim) -> Tensor
-  matches_jit_signature: True
-
 - func: median(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   matches_jit_signature: True
   variants: function, method
@@ -1633,9 +1623,6 @@
   matches_jit_signature: True
   variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too.
 
-- func: permute_backwards(Tensor grad, int[] fwd_dims) -> Tensor
-  matches_jit_signature: True
-
 - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
   matches_jit_signature: True
 
@@ -1910,15 +1897,11 @@
   variants: function, method
   device_guard: False
 
-- func: _safe_size(int[] sizes, int[] dim) -> int64_t
-
 - func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a)
   matches_jit_signature: True
   variants: function, method
   device_guard: False
 
-- func: slice_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) -> Tensor
-
 - func: slogdet(Tensor self) -> (Tensor, Tensor)
   matches_jit_signature: True
   variants: function, method
@@ -2009,11 +1992,6 @@
   variants: function, method
   device_guard: False
 
-- func: unsqueeze_to(Tensor self, int[] sizes) -> Tensor
-  matches_jit_signature: True
-
-- func: unsqueeze_to(Tensor self, int64_t dim, int[] sizes) -> Tensor
-
 - func: squeeze_(Tensor(a!) self) -> Tensor(a!)
   matches_jit_signature: True
   variants: method
@@ -2310,12 +2288,6 @@
 - func: var(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   matches_jit_signature: True
 
-- func: var_backward(Tensor grad, Tensor self, bool unbiased) -> Tensor
-  matches_jit_signature: True
-
-- func: var_backward(Tensor grad, Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor
-  matches_jit_signature: True
-
 - func: view_as(Tensor self, Tensor other) -> Tensor
   matches_jit_signature: True
   variants: method

test/cpp/jit/test_misc.h

Lines changed: 3 additions & 37 deletions
@@ -848,24 +848,7 @@ void testDifferentiate(std::ostream& out = std::cout) {
 
   auto grad_spec = differentiate(graph);
   std::vector<size_t> expected_captured_inputs = {0, 1};
-  // With add/mul implemented using torchscript, we passes sizes of
-  // self & other instead passing the tensors themselve.
-  // The forward graph is now
-  //graph(%0 : Float(2, 3, 4)
-  //      %1 : Float(2, 3, 4)) {
-  //  %2 : Float(2, 3, 4) = aten::mul(%0, %1)
-  //  %self_size.4 : int[] = aten::size(%0)
-  //  %other_size.4 : int[] = aten::size(%1)
-  //  %3 : Float(2, 3, 4) = aten::mul(%2, %0)
-  //  %self_size.2 : int[] = aten::size(%2)
-  //  %4 : int = prim::Constant[value=1]()
-  //  %7 : int[] = aten::size(%3)
-  //  %5 : Float(2, 3, 4) = aten::add(%3, %1, %4)
-  //  return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7);
-  //}
-  // Thus all the sizes info added in forward outputs are saved
-  // in grad_spec.df_input_caputered_outputs.
-  std::vector<size_t> expected_captured_outputs = {1, 2, 3, 4, 5};
+  std::vector<size_t> expected_captured_outputs = {1, 2};
   std::vector<size_t> expected_input_vjps = {0, 1};
   std::vector<size_t> expected_output_vjps = {0, 1};
   ASSERT_EQ(grad_spec.f_real_outputs, 1);
@@ -897,29 +880,12 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
   PropagateInputShapes(graph);
   PropagateRequiresGrad(graph);
 
-  // With add/mul implemented using torchscript, we passes sizes of
-  // self & other instead passing the tensors themselve.
-  // The forward graph is now
-  // graph(%0 : Float(*)
-  //       %1 : Float(*)) {
-  //   %2 : Float(*) = aten::mul(%1, %1)
-  //   %3 : int = prim::Constant[value=1]()
-  //   %4 : Float(*) = aten::add(%2, %1, %3)
-  //   %39 : int[] = aten::size(%0)
-  //   %6 : Float(*) = aten::add(%4, %0, %3)
-  //   %7 : Float(*) = aten::mul(%6, %0)
-  //   %self_size.2 : int[] = aten::size(%6)
-  //   %11 : int[] = aten::size(%7)
-  //   %9 : Float(*) = aten::add(%7, %1, %3)
-  //   return (%4, %9, %39, %6, %self_size.2, %11);
-  // }
-
   auto grad_spec = differentiate(graph);
-  std::vector<size_t> expected_input_vjps = {1, 3}; // for e and %6 = (d + a)
+  std::vector<size_t> expected_input_vjps = {1, 2}; // for e and %4 = (d + a)
   std::vector<size_t> expected_output_vjps = {0}; // only a requires grad
   ASSERT_EQ(grad_spec.f_real_outputs, 2);
   ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector<size_t>({0}));
-  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3, 4, 5}));
+  ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector<size_t>({2, 3}));
   ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps);
   ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps);
   out << "testDifferentiateWithRequiresGrad\n";

test/expect/TestFuser.test_lstm_cuda-backward.expect

Lines changed: 20 additions & 20 deletions
@@ -22,35 +22,35 @@ graph(%0 : Float(*, *),
       %forgetgate : Float(*, *),
       %cellgate : Float(*, *),
       %outgate : Float(*, *),
-      %self_size.5 : int[],
-      %other_size.5 : int[],
-      %self_size.3 : int[],
-      %other_size.3 : int[],
-      %28 : int[],
-      %29 : int[],
-      %30 : Float(*, *),
-      %self_size.1 : int[],
-      %other_size.1 : int[]):
-  %33 : int = prim::Constant[value=1]()
-  %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %30, %self_size.1)
-  %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %30, %0, %outgate, %other_size.5, %self_size.5, %28, %other_size.3, %self_size.3, %29, %other_size.1)
+      %24 : int[],
+      %25 : int[],
+      %26 : Float(*, *)):
+  %27 : int = prim::Constant[value=1]()
+  %28 : int[] = aten::size(%outgate)
+  %29 : int[] = aten::size(%26)
+  %30 : int[] = aten::size(%ingate)
+  %31 : int[] = aten::size(%cellgate)
+  %32 : int[] = aten::size(%forgetgate)
+  %33 : int[] = aten::size(%9)
+  %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28)
+  %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29)
   %39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34)
-  %40 : Tensor = aten::cat(%39, %33)
+  %40 : Tensor = aten::cat(%39, %27)
   %41 : Tensor = aten::_grad_sum_to_size(%40, %19)
   %42 : Tensor = aten::_grad_sum_to_size(%40, %17)
   %43 : Tensor = aten::_grad_sum_to_size(%40, %14)
   %44 : Tensor = aten::_grad_sum_to_size(%40, %15)
   %45 : Float(*, *) = aten::t(%13)
-  %46 : Float(*, *) = aten::mm(%44, %45)
+  %grad_self.7 : Float(*, *) = aten::mm(%44, %45)
   %47 : Float(*, *) = aten::t(%10)
-  %48 : Float(*, *) = aten::mm(%47, %44)
-  %grad_self.7 : Float(*, *) = aten::t(%48)
+  %grad_mat2.1 : Float(*, *) = aten::mm(%47, %44)
+  %grad_self.9 : Float(*, *) = aten::t(%grad_mat2.1)
   %50 : Float(*, *) = aten::t(%12)
-  %51 : Float(*, *) = aten::mm(%43, %50)
+  %grad_self.11 : Float(*, *) = aten::mm(%43, %50)
   %52 : Float(*, *) = aten::t(%11)
-  %53 : Float(*, *) = aten::mm(%52, %43)
-  %grad_self.9 : Float(*, *) = aten::t(%53)
-  return (%grad_other.5, %41, %42, %46, %grad_self.7, %51, %grad_self.9)
+  %grad_mat2.3 : Float(*, *) = aten::mm(%52, %43)
+  %grad_self.13 : Float(*, *) = aten::t(%grad_mat2.3)
+  return (%grad_other.5, %41, %42, %grad_self.7, %grad_self.9, %grad_self.11, %grad_self.13)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Float(*, *),
       %2 : Float(*, *),

test/expect/TestFuser.test_lstm_cuda-forward.expect

Lines changed: 8 additions & 10 deletions
@@ -28,16 +28,14 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *),
   %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16)
   %21 : int[] = prim::BroadcastSizes(%11, %12)
   %22 : int[] = prim::BroadcastSizes(%21, %13)
-  %other_size.6 : int[] = aten::size(%0)
-  %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
-  %31 : int[] = aten::size(%25)
-  %32 : int[] = aten::size(%outgate.1)
-  %33 : int[] = aten::size(%cellgate.1)
-  %34 : int[] = aten::size(%forgetgate.1)
-  %35 : int[] = aten::size(%ingate.1)
-  %36 : int[] = prim::BroadcastSizes(%34, %other_size.6)
-  %37 : int[] = prim::BroadcastSizes(%35, %33)
-  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %other_size.6, %35, %33, %36, %37, %25, %32, %31)
+  %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17)
+  %30 : int[] = aten::size(%0)
+  %31 : int[] = aten::size(%cellgate.1)
+  %32 : int[] = aten::size(%forgetgate.1)
+  %33 : int[] = aten::size(%ingate.1)
+  %34 : int[] = prim::BroadcastSizes(%32, %30)
+  %35 : int[] = prim::BroadcastSizes(%33, %31)
+  return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24)
 with prim::FusionGroup_0 = graph(%0 : Float(*, *),
       %1 : Tensor,
       %2 : Tensor,
