diff --git a/test/expect/TestTensorBoard.test_pytorch_graph.expect b/test/expect/TestTensorBoard.test_pytorch_graph.expect
index b13e4a80c381a6..52d232c98778e5 100644
--- a/test/expect/TestTensorBoard.test_pytorch_graph.expect
+++ b/test/expect/TestTensorBoard.test_pytorch_graph.expect
@@ -26,7 +26,7 @@ node {
 node {
   name: "output/output.1"
   op: "IO Node"
-  input: "myLinear/Linear[l]/22"
+  input: "myLinear/Linear[l]/21"
   attr {
     key: "_output_shapes"
     value {
@@ -50,7 +50,17 @@ node {
   }
 }
 node {
-  name: "myLinear/Linear[l]/bias/17"
+  name: "myLinear/Linear[l]/17"
+  op: "prim::Constant"
+  attr {
+    key: "attr"
+    value {
+      s: "{ value : 1}"
+    }
+  }
+}
+node {
+  name: "myLinear/Linear[l]/bias/18"
   op: "prim::GetAttr"
   input: "myLinear/Linear[l]/weight/14"
   attr {
@@ -61,7 +71,7 @@ node {
   }
 }
 node {
-  name: "myLinear/Linear[l]/weight/18"
+  name: "myLinear/Linear[l]/weight/19"
   op: "prim::GetAttr"
   input: "myLinear/Linear[l]/weight/14"
   attr {
@@ -72,9 +82,9 @@ node {
   }
 }
 node {
-  name: "myLinear/Linear[l]/19"
+  name: "myLinear/Linear[l]/20"
   op: "aten::t"
-  input: "myLinear/Linear[l]/weight/18"
+  input: "myLinear/Linear[l]/weight/19"
   attr {
     key: "_output_shapes"
     value {
@@ -97,34 +107,14 @@ node {
     }
   }
 }
-node {
-  name: "myLinear/Linear[l]/20"
-  op: "prim::Constant"
-  attr {
-    key: "attr"
-    value {
-      s: "{ value : 1}"
-    }
-  }
-}
 node {
   name: "myLinear/Linear[l]/21"
-  op: "prim::Constant"
-  attr {
-    key: "attr"
-    value {
-      s: "{ value : 1}"
-    }
-  }
-}
-node {
-  name: "myLinear/Linear[l]/22"
   op: "aten::addmm"
-  input: "myLinear/Linear[l]/bias/17"
+  input: "myLinear/Linear[l]/bias/18"
   input: "input/input"
-  input: "myLinear/Linear[l]/19"
   input: "myLinear/Linear[l]/20"
-  input: "myLinear/Linear[l]/21"
+  input: "myLinear/Linear[l]/17"
+  input: "myLinear/Linear[l]/17"
   attr {
     key: "_output_shapes"
     value {
diff --git a/test/test_jit.py b/test/test_jit.py
index 82538a96d3c4a7..bca71e629e4fb5 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -5089,7 +5089,7 @@ def fn2(x):
 
         with self.capture_stdout():
             traced = torch.jit.trace(fn, [torch.ones(2, 2)])
-        FileCheck().check("goodbye").check("hello").run(traced.graph)
+        FileCheck().check("goodbye").run(traced.graph)
 
     def test_big_int_literals(self):
         def ok():
@@ -11360,6 +11360,18 @@ def check(mod):
         imported = self.getExportImportCopy(traced)
         check(imported)
 
+    def test_inlining_cleanup(self):
+        def foo(x):
+            return F.linear(x, x)
+
+        @torch.jit.script
+        def fee(x):
+            return foo(x)
+
+        # inlining optimizations should have cleaned up linear if statement
+        self.run_pass("inline", fee.graph)
+        FileCheck().check_not("prim::If").run(fee.graph)
+
     def test_trace_export_fns_recursive(self):
         class Foo(torch.nn.Module):
             def __init__(self):
diff --git a/torch/csrc/jit/api/function_impl.cpp b/torch/csrc/jit/api/function_impl.cpp
index c4eabda759ab1a..712f501d43eda6 100644
--- a/torch/csrc/jit/api/function_impl.cpp
+++ b/torch/csrc/jit/api/function_impl.cpp
@@ -2,6 +2,9 @@
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/peephole.h>
 
 namespace torch {
 namespace jit {
@@ -66,9 +69,14 @@ const c10::FunctionSchema& GraphFunction::getSchema() const {
 }
 
 void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
-  // TODO: Invoke cleanup passes before and after inlining to reduce amount of
-  // code we're copying.
   Inline(*graph);
+  // Peephole Optimize cleans up many "is None" checks and creates constant prop
+  // opportunities
+  PeepholeOptimize(graph);
+  // AliasDb construction can be slow, so run it just on immutable types
+  // to clean up constant Ifs & other easy wins
+  ConstantPropagationImmutableTypes(graph);
+  ConstantPooling(graph);
 }
 
 } // namespace jit
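
What the new test exercises end to end: after this change, inlining a call to F.linear no longer leaves the dead None-check branching behind, so the inlined graph contains no prim::If. Below is a minimal standalone sketch of the same check, assuming a PyTorch build that includes this patch; torch._C._jit_pass_inline is the internal binding that JitTestCase.run_pass("inline", ...) resolves to, and it may change between releases.

import torch
import torch.nn.functional as F

def foo(x):
    return F.linear(x, x)

@torch.jit.script
def fee(x):
    return foo(x)

# Inline foo (and F.linear) into fee's graph; with the new
# PeepholeOptimize + ConstantPropagationImmutableTypes cleanup,
# the constant None-check branches inside linear are folded away.
torch._C._jit_pass_inline(fee.graph)
assert "prim::If" not in str(fee.graph)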