Commit 4d703d0

jjsjann123 authored and facebook-github-bot committed on Feb 5, 2021
Linear autodiff revert revert (pytorch#51613)
Summary: Patch PR pytorch#50856 and roll back the revert D26105797 (pytorch@e488e3c).

Pull Request resolved: pytorch#51613
Reviewed By: mruberry
Differential Revision: D26253999
Pulled By: ngimel
fbshipit-source-id: a20b1591de06dd277e4cd95542e3291a2f5a252c
1 parent 6dcbf39 commit 4d703d0

File tree: 12 files changed, +190 -57 lines
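For context, a minimal sketch (not part of the commit page) of the behavior the title refers to, assuming a PyTorch build that includes this change: `torch.nn.functional.linear` now stays a single `aten::linear` node in TorchScript graphs, and the profiling executor differentiates it through the symbolic gradient added in `torch/csrc/jit/runtime/symbolic_script.cpp` below.

```python
# Not part of the commit: a rough sketch, assuming a build with this patch.
import torch
import torch.nn.functional as F

@torch.jit.script
def scripted_linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor):
    return F.linear(x, w, b)

x = torch.randn(4, 2, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
b = torch.randn(3, requires_grad=True)

# A few warm-up runs so the profiling executor can record shapes and optimize.
for _ in range(3):
    scripted_linear(x, w, b).sum().backward()

# With this patch the optimized graph keeps a single aten::linear call
# (inside a differentiable subgraph) instead of t/matmul/addmm nodes.
print(scripted_linear.graph_for(x, w, b))
```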
 

‎aten/src/ATen/native/Linear.cpp

-14

@@ -13,20 +13,6 @@
 
 namespace at { namespace native {
 
-// in order to dispatch mkldnn linear to addmm which can be removed after linear ported.
-Tensor& mkldnn_addmm_wraper_out(Tensor& result, const Tensor& bias,
-    const Tensor& input, const Tensor& weight, Scalar beta, Scalar alpha) {
-  TORCH_CHECK(false,
-    "mkldnn_addmm_wraper_out: in-place mkldnn operations are not supported yet");
-}
-
-Tensor mkldnn_addmm_wraper(const Tensor& bias, const Tensor& input,
-    const Tensor& weight, Scalar beta, Scalar alpha) {
-  TORCH_CHECK(input.dim() == 2,
-    "mkldnn_addmm_wraper: input needs to has dim 2, input dim ", input.dim());
-  return at::mkldnn_linear(input, weight.t(), bias);
-}
-
 Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
   if (input.is_mkldnn()) {
     return at::mkldnn_linear(input, weight, bias);

‎aten/src/ATen/native/native_functions.yaml

+1 -3

@@ -2172,7 +2172,7 @@
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
 
-- func: mkldnn_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
+- func: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
   use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
   python_module: nn
   dispatch:

@@ -4331,7 +4331,6 @@
     CUDA: addmm_out_cuda
     SparseCPU: addmm_out_sparse_dense_cpu
     SparseCUDA: addmm_out_sparse_dense_cuda
-    MkldnnCPU: mkldnn_addmm_wraper_out
 
 - func: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   variants: function, method

@@ -4340,7 +4339,6 @@
     CUDA: addmm_cuda
     SparseCPU: addmm_sparse_dense_cpu
     SparseCUDA: addmm_sparse_dense_cuda
-    MkldnnCPU: mkldnn_addmm_wraper
 
 - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
   variants: method

‎test/backward_compatibility/check_backward_compatibility.py

+1

@@ -70,6 +70,7 @@
     ("aten::_foreach_div", datetime.date(2021, 2, 25)),
     ("aten::_foreach_div_", datetime.date(2021, 2, 25)),
     ("aten::_foreach_addcdiv", datetime.date(2021, 2, 25)),
+    ("aten::mkldnn_linear", datetime.date(2021, 3, 2)),
 ]
 
 def allow_listed(schema, allow_list):

‎test/jit/test_onnx_export.py

+31

@@ -334,6 +334,37 @@ def transpose(x):
                           (torch.ones(1, 10, dtype=torch.float), ),
                           None, verbose=False, example_outputs=outputs_f2)
 
+    def test_onnx_export_preprocess_decompose_linear(self):
+        def t(x, weight, bias):
+            return torch.nn.functional.linear(x, weight, bias)
+
+        foo = torch.jit.script(t)
+        foo(torch.zeros(2, 4), torch.randn(3, 4), torch.randn(3))
+        # run it twice in case we need to remove profiling nodes
+        graph = foo.graph_for(
+            torch.zeros(2, 4), torch.randn(3, 4), torch.randn(3))
+
+        nodes = []
+        for n in graph.nodes():
+            nodes.append(n.kind())
+        self.assertEqual(nodes, ['aten::linear'])
+        torch._C._jit_pass_onnx_preprocess(graph)
+
+        nodes = []
+        for n in graph.nodes():
+            nodes.append(n.kind())
+            for b in n.blocks():
+                nodes_b = []
+                for n_n in b.nodes():
+                    nodes_b.append(n_n.kind())
+                nodes.append(nodes_b)
+
+        self.assertEqual(
+            nodes,
+            ['aten::dim', 'prim::Constant', 'aten::eq', 'prim::If',
+             ['prim::Constant', 'aten::t', 'aten::addmm'],
+             ['prim::Constant', 'aten::t', 'aten::matmul', 'aten::add']])
+
     def test_onnx_export_shape_reshape(self):
         class Foo(torch.nn.Module):
             def forward(self, x):

‎test/jit/test_remove_mutation.py

+1 -22

@@ -2,40 +2,19 @@
 import sys
 
 import torch
-from torch.nn import functional as F
 from torch.testing import FileCheck
 
 # Make the helper files in test/ importable
 pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(pytorch_test_dir)
-from torch.testing._internal.jit_utils import JitTestCase, freeze_rng_state
+from torch.testing._internal.jit_utils import JitTestCase
 
 if __name__ == '__main__':
     raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                        "\tpython test/test_jit.py TESTNAME\n\n"
                        "instead.")
 
 class TestRemoveMutation(JitTestCase):
-    def test_lower_linear(self):
-        # linear is one of main use cases of removing mutation so add test so it doesnt regress
-        @torch.jit.script
-        def foo(x):
-            return F.linear(x, torch.randn(20, 20), torch.randn(20))
-
-        self.run_pass('inline', foo.graph)
-        self.run_pass('peephole', foo.graph)
-        self.run_pass('constant_propagation', foo.graph)
-        FileCheck().check("aten::add_").run(foo.graph)
-        input = torch.randn(20, 20)
-        with freeze_rng_state():
-            out1 = foo(input)
-
-        self.run_pass('remove_mutation', foo.graph)
-        FileCheck().check_not("aten::add_").run(foo.graph)
-        with freeze_rng_state():
-            out2 = foo(input)
-        self.assertEqual(out1, out2)
-
     def test_aten_inplace(self):
         def test_not_new_alias(x):
             y = x[0]

‎test/test_autograd.py

+2 -4

@@ -3191,10 +3191,8 @@ def test_profiler_shapes(self):
         print(prof.function_events)
 
         top_level_expected_events_and_shapes = [
-            (None, [[30, 20]]),
-            ('aten::addmm', [[30], [128, 20], [20, 30], [], []]),
-            (None, [[40, 30]]),
-            ('aten::addmm', [[40], [128, 30], [30, 40], [], []])
+            ('aten::linear', [[128, 20], [30, 20], [30]]),
+            ('aten::linear', [[128, 30], [40, 30], [40]])
         ]
 
         expected_iter = iter(top_level_expected_events_and_shapes)
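The updated expectation reflects that, with `F.linear` no longer decomposed in Python, the profiler records one `aten::linear` event per layer whose recorded input shapes are `[input, weight, bias]`. A rough usage sketch, not from the commit (the exact op names in the output depend on the build):

```python
import torch

# Two Linear layers mirroring the shapes checked in the test above.
model = torch.nn.Sequential(torch.nn.Linear(20, 30), torch.nn.Linear(30, 40))
inp = torch.randn(128, 20)

# record_shapes=True attaches input shapes to each profiled op.
with torch.autograd.profiler.profile(record_shapes=True) as prof:
    model(inp)

for evt in prof.function_events:
    if evt.name == "aten::linear":
        # Expected shapes: [[128, 20], [30, 20], [30]] and [[128, 30], [40, 30], [40]]
        print(evt.name, evt.input_shapes)
```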

‎test/test_jit.py

+53

@@ -10665,6 +10665,59 @@ def randint():
         FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \
                    .check_not("Float(*, *, requires_grad=0, device=cpu)").run(randint.graph_for())
 
+    def test_linear_grad(self):
+        with enable_profiling_mode_for_profiling_tests():
+            def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]):
+                return torch.nn.functional.linear(x, w, b)
+
+            x_init = torch.randn(4, 2)
+            w_init = torch.randn(3, 2)
+            b_init = torch.randn(3)
+            grad = torch.randn(4, 3)
+
+            with disable_autodiff_subgraph_inlining():
+                # script module
+                jit_t = torch.jit.script(t)
+
+                x = x_init.detach().requires_grad_()
+                w = w_init.detach().requires_grad_()
+                b = b_init.detach().requires_grad_()
+                x_ref = x_init.detach().requires_grad_()
+                w_ref = w_init.detach().requires_grad_()
+                b_ref = b_init.detach().requires_grad_()
+
+                # profiling/optimization runs
+                jit_o = jit_t(x, w, b)
+                jit_o.backward(grad)
+                jit_o = jit_t(x, w, b)
+                jit_o.backward(grad)
+
+                x.grad.zero_()
+                w.grad.zero_()
+                b.grad.zero_()
+                jit_o = jit_t(x, w, b)
+                jit_o.backward(grad)
+                o = t(x_ref, w_ref, b_ref)
+                o.backward(grad)
+
+                self.assertEqual(jit_o, o)
+                self.assertEqual(x.grad, x_ref.grad)
+                self.assertEqual(w.grad, w_ref.grad)
+                self.assertEqual(b.grad, b_ref.grad)
+
+                x.grad.zero_()
+                w.grad.zero_()
+                x_ref.grad.zero_()
+                w_ref.grad.zero_()
+                jit_o = jit_t(x, w, None)
+                jit_o.backward(grad)
+                o = t(x_ref, w_ref, None)
+                o.backward(grad)
+
+                self.assertEqual(jit_o, o)
+                self.assertEqual(x.grad, x_ref.grad)
+                self.assertEqual(w.grad, w_ref.grad)
+
     @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand")
     def test_rand_profiling(self):
         def test_rand():

‎tools/autograd/derivatives.yaml

+6 -3

@@ -187,9 +187,9 @@
   tensor2: handle_r_to_c(tensor2.scalar_type(), grad * (tensor1 * value).conj())
 
 - name: addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
-  self: "grad.is_mkldnn() ? grad.to_dense() : maybe_multiply(grad, beta.conj())"
-  mat1: "grad.is_mkldnn() ? mkldnn_linear_backward_input(mat1.sizes(), grad, mat2.t()) : mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)"
-  mat2: "grad.is_mkldnn() ? (std::get<0>(mkldnn_linear_backward_weights(grad, mat1, mat2.t(), true))).t() : mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)"
+  self: maybe_multiply(grad, beta.conj())
+  mat1: mm_mat1_backward(grad, mat2, mat1.sizes(), mat1.strides(), alpha)
+  mat2: mm_mat2_backward(grad, mat1, mat2.sizes(), mat2.strides(), alpha)
 
 - name: _sparse_addmm(Tensor self, Tensor sparse, Tensor dense, *, Scalar beta=1, Scalar alpha=1) -> Tensor
   self: maybe_multiply(grad, beta)

@@ -1888,6 +1888,9 @@
 - name: mkldnn_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, int[] padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
   grad_output, self, weight: _convolution_double_backward(grads[0], grads[1], grads[2], grad_output, weight, self, stride, padding, dilation, false, std::vector<int64_t>(padding.size(), 0), groups, false, false, false, false, grad_input_mask)
 
+- name: mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor
+  self, weight, bias: mkldnn_linear_backward(self, grad, weight, grad_input_mask)
+
 # fft
 - name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
   self: fft_r2c_backward(grad, dim, normalization, onesided, self.size(dim.back()))
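With the MKL-DNN branch removed, the `addmm` entries reduce to the standard dense formulas: for `out = beta * self + alpha * mat1 @ mat2`, the gradients are `beta * grad`, `alpha * grad @ mat2^T`, and `alpha * mat1^T @ grad` (the `conj()` is a no-op for real inputs). A small numeric check against autograd, not part of the commit:

```python
import torch

beta, alpha = 0.5, 2.0
self_t = torch.randn(4, 3, requires_grad=True)
mat1 = torch.randn(4, 5, requires_grad=True)
mat2 = torch.randn(5, 3, requires_grad=True)
grad = torch.randn(4, 3)

out = torch.addmm(self_t, mat1, mat2, beta=beta, alpha=alpha)
out.backward(grad)

# The dense formulas the simplified derivatives.yaml entries compute.
assert torch.allclose(self_t.grad, beta * grad)
assert torch.allclose(mat1.grad, alpha * grad @ mat2.t())
assert torch.allclose(mat2.grad, alpha * mat1.t() @ grad)
```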

‎torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp

+75

@@ -217,9 +217,84 @@ static void fuseListAndListUnpack(Block* b) {
   }
 }
 
+static void decomposeLinear(Block* b) {
+  std::vector<Node*> linear_nodes;
+  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
+    for (auto* child_block : it->blocks()) {
+      decomposeLinear(child_block);
+    }
+    if (it->kind() == aten::linear) {
+      linear_nodes.push_back(*it);
+    }
+  }
+  for (Node* node : linear_nodes) {
+    auto g = b->owningGraph();
+
+    if (node->inputs()[2]->mustBeNone()) {
+      auto t_weight_n =
+          g->create(aten::t, {node->inputs()[1]}, 1)->insertBefore(node);
+      auto matmul_n =
+          g->create(aten::matmul, {node->inputs()[0], t_weight_n->output()}, 1)
+              ->insertBefore(node);
+      node->output()->replaceAllUsesWith(matmul_n->output());
+      node->destroy();
+    } else {
+      auto dim_n =
+          g->create(aten::dim, {node->inputs()[0]}, 1)->insertBefore(node);
+      auto const_2 = g->insertConstant(IValue(2));
+      const_2->node()->moveBefore(node);
+      auto eq_n = g->create(aten::eq, {dim_n->output(), const_2}, 1)
+                      ->insertBefore(node);
+
+      auto if_n = g->create(prim::If, {eq_n->output()}, 1)->insertBefore(node);
+
+      auto true_block = if_n->addBlock();
+      auto false_block = if_n->addBlock();
+
+      {
+        WithInsertPoint guard(true_block->return_node());
+        auto const_1 = g->insertConstant(IValue(1.0));
+        auto t_weight_n = g->create(aten::t, {node->inputs()[1]}, 1)
+                              ->insertBefore(true_block->return_node());
+        auto addmm_n = g->create(
+                            aten::addmm,
+                            {node->inputs()[2],
+                             node->inputs()[0],
+                             t_weight_n->output(),
+                             const_1,
+                             const_1},
+                            1)
+                           ->insertBefore(true_block->return_node());
+        true_block->registerOutput(addmm_n->output());
+      }
+
+      {
+        WithInsertPoint guard(false_block->return_node());
+        auto const_1 = g->insertConstant(IValue(1.0));
+        auto t_weight_n = g->create(aten::t, {node->inputs()[1]}, 1)
+                              ->insertBefore(false_block->return_node());
+        auto matmul_n =
+            g->create(
+                 aten::matmul, {node->inputs()[0], t_weight_n->output()}, 1)
+                ->insertBefore(false_block->return_node());
+        auto add_n =
+            g->create(
+                 aten::add, {matmul_n->output(), node->inputs()[2], const_1}, 1)
+                ->insertBefore(false_block->return_node());
+        false_block->registerOutput(add_n->output());
+      }
+      node->output()->replaceAllUsesWith(if_n->output());
+      node->destroy();
+    }
+  }
+}
+
 } // namespace
 
 void PreprocessForONNX(std::shared_ptr<Graph>& graph) {
+  GRAPH_DEBUG("priot to decompose linear", graph);
+  decomposeLinear(graph->block());
+  GRAPH_DEBUG("after decompose linear", graph);
   FuseWithListUnpack(graph->block());
   ReplaceAddWithConcat(graph->block());
   fuseListAndListUnpack(graph->block());
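`decomposeLinear` rewrites each `aten::linear` node before ONNX export: without a bias it becomes `aten::t` plus `aten::matmul`; with a bias it becomes a `prim::If` on `input.dim() == 2` whose branches are `aten::addmm` and `aten::matmul` plus `aten::add`. A tensor-level sketch of the same branching (not the graph pass itself, not part of the commit):

```python
import torch
from typing import Optional

def decomposed_linear(x: torch.Tensor, weight: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
    # Mirrors the graph rewrite above at the tensor level.
    if bias is None:
        return torch.matmul(x, weight.t())
    if x.dim() == 2:
        # true branch of the emitted prim::If
        return torch.addmm(bias, x, weight.t())
    # false branch of the emitted prim::If
    return torch.matmul(x, weight.t()) + bias

w, b = torch.randn(3, 4), torch.randn(3)
for x in (torch.randn(2, 4), torch.randn(5, 2, 4)):
    ref = torch.nn.functional.linear(x, w, b)
    assert torch.allclose(decomposed_linear(x, w, b), ref, atol=1e-6)
```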

‎torch/csrc/jit/runtime/symbolic_script.cpp

+17

@@ -403,6 +403,23 @@ const std::vector<std::string> functions = {
                 return grad_self, grad_other
 
             return torch.matmul(self, other), backward
+
+        def linear(input : Tensor,
+                   weight : Tensor,
+                   bias : Optional[Tensor]):
+            result = torch.linear(input, weight, bias)
+
+            def backward(grad_output):
+                if bias is not None:
+                    grad_bias = grad_output._grad_sum_to_size(bias.size())
+                else:
+                    grad_bias = None
+
+                weight_size = weight.size()
+                grad_input = torch.matmul(grad_output, weight)
+                grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
+                return grad_input, grad_weight, grad_bias
+            return result, backward
     )",
     R"(
         def addcmul(self,
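The new TorchScript entry keeps the forward as a single `torch.linear` call and supplies closed-form gradients: `grad_input = grad_output @ weight`, `grad_weight` from the flattened batch, and `grad_bias` reduced back to the bias shape. A check against autograd, not part of the commit; a plain `sum(0)` stands in for the internal `_grad_sum_to_size`:

```python
import torch

x = torch.randn(5, 4, 2, requires_grad=True)  # batched, non-2D input
w = torch.randn(3, 2, requires_grad=True)
b = torch.randn(3, requires_grad=True)
grad_out = torch.randn(5, 4, 3)

torch.nn.functional.linear(x, w, b).backward(grad_out)

# Formulas from the backward defined above.
grad_input = torch.matmul(grad_out, w)
grad_weight = torch.matmul(grad_out.reshape(-1, w.size(0)).t(),
                           x.reshape(-1, w.size(1)))
grad_bias = grad_out.reshape(-1, w.size(0)).sum(0)

assert torch.allclose(x.grad, grad_input, atol=1e-5)
assert torch.allclose(w.grad, grad_weight, atol=1e-5)
assert torch.allclose(b.grad, grad_bias, atol=1e-5)
```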

‎torch/nn/functional.py

+1 -9

@@ -1750,15 +1750,7 @@ def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tens
     """
     if has_torch_function_variadic(input, weight):
         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-    if input.dim() == 2 and bias is not None:
-        # fused op is marginally faster
-        ret = torch.addmm(bias, input, weight.t())
-    else:
-        output = input.matmul(weight.t())
-        if bias is not None:
-            output += bias
-        ret = output
-    return ret
+    return torch._C._nn.linear(input, weight, bias)
 
 
 def bilinear(input1: Tensor, input2: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
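`F.linear` now calls straight into the C++ `linear` op; the removed Python fast path (fused `addmm` for 2-D input with bias) is numerically equivalent. A quick comparison, not part of the commit, with the old decomposition reproduced only for the check:

```python
import torch
from typing import Optional

def old_linear(input: torch.Tensor, weight: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # The removed Python decomposition, reproduced here only for comparison.
    if input.dim() == 2 and bias is not None:
        return torch.addmm(bias, input, weight.t())
    output = input.matmul(weight.t())
    if bias is not None:
        output += bias
    return output

x, w, b = torch.randn(128, 20), torch.randn(30, 20), torch.randn(30)
new = torch.nn.functional.linear(x, w, b)  # now dispatches to torch._C._nn.linear
assert torch.allclose(new, old_linear(x, w, b), atol=1e-6)
```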

‎torch/testing/_internal/jit_metaprogramming_utils.py

+2 -2

@@ -103,8 +103,8 @@
     ('tanh', (S, S, S), (), '', (True,)),
     ('sigmoid', (S, S, S), (), '', (True,)),
     ('log_softmax', (S, S, S), (0,), '', (True,)),
-    ('linear', (S, S), ((M, S),), '', (True, ['aten::t', 'aten::matmul'])),
-    ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::add', 'aten::mm'])),
+    ('linear', (S, S), ((M, S),), '', (True, ['aten::linear'])),
+    ('linear', (S, S), ((M, S), (M,)), 'addmm', (True, ['aten::linear'])),
     ('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
     ('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ), '', (True,)),
     ('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
