Skip to content

Commit

Permalink
[JIT] Optimize before inlining (pytorch#35562)
Browse files Browse the repository at this point in the history
Summary:
Resubmit of pytorch#35424, only this time I run optimizations in the right order so the PR description is actually true.

This speeds up the inlining pass of FairSeq model from 180s -> 13s, and MaskRCNN model from 5s -> 1.5s.
Pull Request resolved: pytorch#35562

Differential Revision: D20738922

Pulled By: eellison

fbshipit-source-id: 1439cf9d1f0bc780e2d64a744694f8b3b7ba4b70
  • Loading branch information
Elias Ellison authored and ashish committed Apr 13, 2020
1 parent 136cd9d commit 6c722d3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 31 deletions.
46 changes: 18 additions & 28 deletions test/expect/TestTensorBoard.test_pytorch_graph.expect
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ node {
node {
name: "output/output.1"
op: "IO Node"
input: "myLinear/Linear[l]/22"
input: "myLinear/Linear[l]/21"
attr {
key: "_output_shapes"
value {
Expand All @@ -50,7 +50,17 @@ node {
}
}
node {
name: "myLinear/Linear[l]/bias/17"
name: "myLinear/Linear[l]/17"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/bias/18"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
Expand All @@ -61,7 +71,7 @@ node {
}
}
node {
name: "myLinear/Linear[l]/weight/18"
name: "myLinear/Linear[l]/weight/19"
op: "prim::GetAttr"
input: "myLinear/Linear[l]/weight/14"
attr {
Expand All @@ -72,9 +82,9 @@ node {
}
}
node {
name: "myLinear/Linear[l]/19"
name: "myLinear/Linear[l]/20"
op: "aten::t"
input: "myLinear/Linear[l]/weight/18"
input: "myLinear/Linear[l]/weight/19"
attr {
key: "_output_shapes"
value {
Expand All @@ -97,34 +107,14 @@ node {
}
}
}
node {
name: "myLinear/Linear[l]/20"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/21"
op: "prim::Constant"
attr {
key: "attr"
value {
s: "{ value : 1}"
}
}
}
node {
name: "myLinear/Linear[l]/22"
op: "aten::addmm"
input: "myLinear/Linear[l]/bias/17"
input: "myLinear/Linear[l]/bias/18"
input: "input/input"
input: "myLinear/Linear[l]/19"
input: "myLinear/Linear[l]/20"
input: "myLinear/Linear[l]/21"
input: "myLinear/Linear[l]/17"
input: "myLinear/Linear[l]/17"
attr {
key: "_output_shapes"
value {
Expand Down
14 changes: 13 additions & 1 deletion test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5089,7 +5089,7 @@ def fn2(x):
with self.capture_stdout():
traced = torch.jit.trace(fn, [torch.ones(2, 2)])

FileCheck().check("goodbye").check("hello").run(traced.graph)
FileCheck().check("goodbye").run(traced.graph)

def test_big_int_literals(self):
def ok():
Expand Down Expand Up @@ -11360,6 +11360,18 @@ def check(mod):
imported = self.getExportImportCopy(traced)
check(imported)

def test_inlining_cleanup(self):
def foo(x):
return F.linear(x, x)

@torch.jit.script
def fee(x):
return foo(x)

# inlining optimizations should have cleaned up linear if statement
self.run_pass("inline", fee.graph)
FileCheck().check_not("prim::If").run(fee.graph)

def test_trace_export_fns_recursive(self):
class Foo(torch.nn.Module):
def __init__(self):
Expand Down
12 changes: 10 additions & 2 deletions torch/csrc/jit/api/function_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include <torch/csrc/jit/passes/inliner.h>

#include <torch/csrc/jit/frontend/error_report.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>

namespace torch {
namespace jit {
Expand Down Expand Up @@ -66,9 +69,14 @@ const c10::FunctionSchema& GraphFunction::getSchema() const {
}

void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
// TODO: Invoke cleanup passes before and after inlining to reduce amount of
// code we're copying.
Inline(*graph);
// Peephole Optimize cleans up many "is None" checks and creates constant prop
// opportunities
PeepholeOptimize(graph);
// // AliasDb construction can be slow, so run it just on immutable types
// // to clean up constant Ifs & other easy wins
ConstantPropagationImmutableTypes(graph);
ConstantPooling(graph);
}

} // namespace jit
Expand Down

0 comments on commit 6c722d3

Please sign in to comment.