
Commit 0f9807e

James Reed authored and facebook-github-bot committed
Enable addmm fusion for ONNX export only (pytorch#12538)
Summary: There are some action-at-a-distance issues, and not having this is disabling quantization in C2 for prod use cases (ref T34831022).

Pull Request resolved: pytorch#12538
Differential Revision: D10302931
Pulled By: jamesr66a
fbshipit-source-id: 700dc8c5c4297e942171992266ffb67b815be754
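For context, the fusion this flag controls rewrites a matrix multiply followed by an add into a single addmm, which the ONNX exporter can then emit as one Gemm node. A minimal sketch of the equivalence (not part of this commit; uses only public torch ops):

import torch

x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)

# Unfused form: what the normal JIT graph keeps (aten::mm followed by aten::add).
unfused = torch.mm(x, y) + c

# Fused form: a single addmm with alpha == beta == 1, which maps onto one ONNX Gemm node.
fused = torch.addmm(c, x, y)

assert torch.allclose(unfused, fused)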
1 parent 7b0f5d6 commit 0f9807e

File tree: 8 files changed, +60 −21 lines changed

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+graph(%0 : Double(*, *)
+      %1 : Double(*, *)
+      %2 : Double(*, *)) {
+  %3 : int = prim::Constant[value=1]()
+  %4 : Double(*, *) = aten::mm(%0, %1)
+  %5 : Double(*, *) = aten::add(%4, %2, %3)
+  return (%5);
+}
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+ModelProto {
+  producer_name: "pytorch"
+  domain: ""
+  doc_string: ""
+  graph:
+    GraphProto {
+      name: "torch-jit-export"
+      inputs: [{name: "0", type:Tensor dims: 3 4},{name: "1", type:Tensor dims: 4 5},{name: "2", type:Tensor dims: 3 5}]
+      outputs: [{name: "3", type:Tensor dims: 3 5}]
+      initializers: []
+      nodes: [
+        Node {type: "Gemm", inputs: [0,1,2], outputs: [3], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1}]}
+      ]
+    }
+  opset_import: [OperatorSetIdProto { domain: }],
+}
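The two expect files above are the outputs checked by the new test_addmm_fusion test below: the first keeps aten::mm + aten::add in the JIT graph, while the ONNX export collapses them into a single Gemm node. A rough sketch of how such output can be regenerated interactively (file naming and exact formatting aside; assumes the same APIs the test uses):

import io
import torch

class AddmmWrapper(torch.nn.Module):
    def forward(self, x, y, c):
        return torch.mm(x, y) + c

x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)

# ONNX path: peephole runs with addmm_fusion_enabled=True, so a single Gemm appears.
print(torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), io.BytesIO()))

# JIT path: fusion stays off, so the traced graph still shows aten::mm + aten::add.
traced = torch.jit.trace(AddmmWrapper(), (x, y, c))
print(traced.graph)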

test/expect/TestScript.test_onnx_export_speculate-f2.expect

Lines changed: 6 additions & 12 deletions
@@ -24,25 +24,21 @@ ModelProto {
       GraphProto {
         name: "torch-jit-export2"
         inputs: []
-        outputs: [{name: "11", type:Tensor dims: 1 20}]
+        outputs: [{name: "9", type:Tensor dims: 1 20}]
         initializers: []
         nodes: [
-          Node {type: "Constant", inputs: [], outputs: [9], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
-          Node {type: "Gemm", inputs: [3,1,9], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'transB', type: int, value: 1}]},
-          Node {type: "Add", inputs: [2,10], outputs: [11], attributes: []}
+          Node {type: "Gemm", inputs: [3,1,2], outputs: [9], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
         ]
       }
 
     },{ name: 'else_branch', type: graph, value:
       GraphProto {
         name: "torch-jit-export3"
         inputs: []
-        outputs: [{name: "14", type:Tensor dims: 1 20}]
+        outputs: [{name: "10", type:Tensor dims: 1 20}]
         initializers: []
         nodes: [
-          Node {type: "Constant", inputs: [], outputs: [12], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
-          Node {type: "Gemm", inputs: [3,1,12], outputs: [13], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'transB', type: int, value: 1}]},
-          Node {type: "Add", inputs: [2,13], outputs: [14], attributes: []}
+          Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
         ]
       }
 
@@ -54,12 +50,10 @@ ModelProto {
       GraphProto {
         name: "torch-jit-export4"
         inputs: []
-        outputs: [{name: "17", type:Tensor dims: 1 20}]
+        outputs: [{name: "11", type:Tensor dims: 1 20}]
         initializers: []
         nodes: [
-          Node {type: "Constant", inputs: [], outputs: [15], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: [1]}]},
-          Node {type: "Gemm", inputs: [3,1,15], outputs: [16], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 0},{ name: 'transB', type: int, value: 1}]},
-          Node {type: "Add", inputs: [2,16], outputs: [17], attributes: []}
+          Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
         ]
       }
 

test/test_jit.py

Lines changed: 15 additions & 0 deletions
@@ -7222,6 +7222,21 @@ def elif_test(niter : int):
 
         self.checkScript(code, (101,), name='elif_test', outputs=3028)
 
+    def test_addmm_fusion(self):
+        class AddmmWrapper(torch.nn.Module):
+            def forward(self, x, y, c):
+                return torch.mm(x, y) + c
+
+        # Test addmm fusion is disabled for normal Jit
+        x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
+        f = io.BytesIO()
+        pretty = torch.onnx.export_to_pretty_string(AddmmWrapper(), (x, y, c), f)
+        self.assertExpected(pretty, 'onnx')
+
+        jit_trace = torch.jit.trace(AddmmWrapper(), (x, y, c))
+        ge_graph = jit_trace.__getattr__('forward').graph_for(x, y, c)
+        self.assertExpectedGraph(ge_graph, 'jit')
+
     def test_weak_script_function(self):
         outer_var = 10
         outer_var2 = 11

torch/csrc/jit/init.cpp

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ void initJITBindings(PyObject *module) {
         return EliminateCommonSubexpression(g); // overload resolution
       })
     .def("_jit_pass_constant_pooling", ConstantPooling)
-    .def("_jit_pass_peephole", PeepholeOptimize)
+    .def("_jit_pass_peephole", PeepholeOptimize, py::arg("graph"), py::arg("addmm_fusion_enabled") = false)
     .def("_jit_pass_canonicalize", [](const std::shared_ptr<Graph>& g) {
         return Canonicalize(g);
       })
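With py::arg("addmm_fusion_enabled") = false registered above, the flag can also be toggled explicitly from Python when experimenting with the pass. A minimal sketch against the private binding (assumes a scripted function and that these internal APIs stay as registered here):

import torch

@torch.jit.script
def mm_add(x, y, c):
    return torch.mm(x, y) + c

g = mm_add.graph
torch._C._jit_pass_peephole(g)        # default: addmm fusion stays off
torch._C._jit_pass_peephole(g, True)  # opt in, as the ONNX exporter now does
print(g)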

torch/csrc/jit/passes/peephole.cpp

Lines changed: 11 additions & 5 deletions
@@ -14,12 +14,19 @@ namespace torch { namespace jit {
 // - Simply x.t().t() to x
 //
 // TODO: Decide what kind of fixed point strategy we will have
-void PeepholeOptimize(Block * block) {
+//
+// The parameter `addmm_fusion_enabled` exists because, as it is today, fusing
+// add + mm has no benefit within PyTorch running ATen ops. However, we rely on
+// seeing the fused version of addmm for ONNX export, since after ONNX translation
+// we would see redundant Gemm ops with sub-optimal inputs. This flag is exposed
+// so that ONNX export can pass `true` to get the fused behavior, but normal
+// JIT peephole optimization is left alone.
+void PeepholeOptimize(Block * block, bool addmm_fusion_enabled) {
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
     auto* node = *it;
 
     for (Block * sub_block : node->blocks()) {
-      PeepholeOptimize(sub_block);
+      PeepholeOptimize(sub_block, addmm_fusion_enabled);
     }
 
     // XXX: remember that if you want to simplify an expression by combining multiple nodes
@@ -60,7 +67,6 @@ void PeepholeOptimize(Block * block) {
       // and because it works out of place on C, we're only trading off an explicit add for
       // a copy inside the addmm function. Note that it doesn't even result in fewer reads,
       // because mm won't even load C (because beta == 0 for it).
-      static constexpr bool addmm_fusion_enabled = false;
       if (addmm_fusion_enabled && node->get<at::Scalar>(attr::alpha).value().toDouble() == 1.) {
         // Look for mm from both sides of the add
         for (size_t mm_side = 0; mm_side < 2; mm_side++) {
@@ -123,8 +129,8 @@ void PeepholeOptimize(Block * block) {
   }
 }
 
-void PeepholeOptimize(std::shared_ptr<Graph>& graph) {
-  PeepholeOptimize(graph->block());
+void PeepholeOptimize(std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled) {
+  PeepholeOptimize(graph->block(), addmm_fusion_enabled);
   // Eliminate dead code created by any peephole passes we've just done
   EliminateDeadCode(graph->block());
 }

torch/csrc/jit/passes/peephole.h

Lines changed: 1 addition & 1 deletion
@@ -4,6 +4,6 @@
 
 namespace torch { namespace jit {
 
-TORCH_API void PeepholeOptimize(std::shared_ptr<Graph>& graph);
+TORCH_API void PeepholeOptimize(std::shared_ptr<Graph>& graph, bool addmm_fusion_enabled=false);
 
 }}

torch/onnx/utils.py

Lines changed: 2 additions & 2 deletions
@@ -138,7 +138,7 @@ def _optimize_graph(graph, operator_export_type):
     torch._C._jit_pass_canonicalize_ops(graph)
     torch._C._jit_pass_lint(graph)
 
-    torch._C._jit_pass_peephole(graph)
+    torch._C._jit_pass_peephole(graph, True)
     torch._C._jit_pass_lint(graph)
 
     # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
@@ -147,7 +147,7 @@ def _optimize_graph(graph, operator_export_type):
     torch._C._jit_pass_erase_number_types(graph)
    # onnx does not support tuples, so try to remove them
     torch._C._jit_pass_lower_all_tuples(graph)
-    torch._C._jit_pass_peephole(graph)
+    torch._C._jit_pass_peephole(graph, True)
     torch._C._jit_pass_lint(graph)
 
     if operator_export_type != OperatorExportTypes.RAW:
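Since _optimize_graph now always passes True, every ONNX export goes through the fused path. A quick end-to-end sanity check (a sketch, assuming the standalone onnx package is available for inspecting the exported proto):

import io
import onnx
import torch

class AddmmWrapper(torch.nn.Module):
    def forward(self, x, y, c):
        return torch.mm(x, y) + c

x, y, c = torch.rand(3, 4), torch.rand(4, 5), torch.rand(3, 5)
buf = io.BytesIO()
torch.onnx.export(AddmmWrapper(), (x, y, c), buf)

model = onnx.load_from_string(buf.getvalue())
# Expect a single Gemm node rather than MatMul + Add.
print([node.op_type for node in model.graph.node])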

0 commit comments
