diff --git a/src/passes/Inlining.cpp b/src/passes/Inlining.cpp index 876c186d471..2c5db3e5d3a 100644 --- a/src/passes/Inlining.cpp +++ b/src/passes/Inlining.cpp @@ -46,6 +46,7 @@ #include "pass.h" #include "passes/opt-utils.h" #include "wasm-builder.h" +#include "wasm-traversal.h" #include "wasm.h" namespace wasm { @@ -175,6 +176,62 @@ static bool canHandleParams(Function* func) { using NameInfoMap = std::unordered_map; +struct InlineMeasurer + : public PostWalker> { + Index size = 0; + Index orderedGetCount = 0; + Index expectedGetIndex = 0; + bool isViolationOrderedAssumption = false; + + void visitExpression(Expression* curr) { + size++; + // Also count the local.gets that appear in strictly increasing index order. + if (isViolationOrderedAssumption) { + return; + } + if (invalidOperation(curr)) { + isViolationOrderedAssumption = true; + return; + } + if (LocalGet* localGet = curr->dynCast()) { + if (localGet->index >= expectedGetIndex) { + expectedGetIndex = localGet->index + 1; + orderedGetCount++; + } else { + // A repeated or out-of-order local.get: fall back to the normal size, + // since inlining will introduce temporary locals in this case. + isViolationOrderedAssumption = true; + } + } + } + + bool invalidOperation(Expression* curr) { + return curr->is() || curr->is() || curr->is(); + } + + // Measure the number of expressions for inlining purposes. This is similar to + // Measurer::measure, but does not count local.gets that read the parameters in + // increasing index order (only when the function declares no vars). + static Index measure(Function* func) { + InlineMeasurer measurer; + measurer.isViolationOrderedAssumption = func->getNumVars() != 0; + measurer.walkFunction(func); + // We don't count the local.gets when their order is the same as the + // parameters' order. + // This enables inlining of functions like: + // (func $foo (param $x i32) (param $y i32) + // (call $bar + // (local.get $x) + // (local.get $y) + // ) + // ) + return measurer.size - (measurer.isViolationOrderedAssumption + ? 
0 + : measurer.orderedGetCount); + } +}; + struct FunctionInfoScanner : public WalkerPass> { bool isFunctionParallel() override { return true; } @@ -224,7 +281,7 @@ struct FunctionInfoScanner info.inliningMode = InliningMode::Uninlineable; } - info.size = Measurer::measure(curr->body); + info.size = InlineMeasurer::measure(curr); if (auto* call = curr->body->dynCast()) { if (info.size == call->operands.size() + 1) { diff --git a/test/lit/passes/inlining-small-function-always.wast b/test/lit/passes/inlining-small-function-always.wast new file mode 100644 index 00000000000..6656c20fde8 --- /dev/null +++ b/test/lit/passes/inlining-small-function-always.wast @@ -0,0 +1,253 @@ +;; NOTE: Assertions have been generated by update_lit_checks.py --all-items and should not be edited. +;; RUN: foreach %s %t wasm-opt -all --inlining --shrink-level=2 -S -o - | filecheck %s + +(module $call_two + ;; CHECK: (type $0 (func (param i32 i32) (result i32))) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (export "add" (func $add)) + (export "add" (func $add)) + + ;; CHECK: (func $add (type $0) (param $0 i32) (param $1 i32) (result i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $add (param i32) (param i32) (result i32) + (i32.add (local.get 0) (local.get 1)) + ) + + ;; CHECK: (func $call_add (type $1) (param $0 i32) (result i32) + ;; CHECK-NEXT: (local $1 i32) + ;; CHECK-NEXT: (local $2 i32) + ;; CHECK-NEXT: (block $__inlined_func$add (result i32) + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $call_add (param i32) (result i32) + (call $add (local.get 0) (i32.const 1)) + ) +) + +(module $call_three + ;; 
CHECK: (type $0 (func (param i32 i32 i32) (result i32))) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (export "callee" (func $callee)) + (export "callee" (func $callee)) + + ;; CHECK: (func $callee (type $0) (param $0 i32) (param $1 i32) (param $2 i32) (result i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $callee (param i32 i32 i32) (result i32) + (i32.add + (local.get 0) + (i32.add (local.get 1) (local.get 2)) + ) + ) + + ;; CHECK: (func $caller (type $1) (param $0 i32) (result i32) + ;; CHECK-NEXT: (local $1 i32) + ;; CHECK-NEXT: (local $2 i32) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (block $__inlined_func$callee (result i32) + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $caller (param i32) (result i32) + (call $callee (local.get 0) (i32.const 1) (local.get 0)) + ) +) + +(module $skip_part_of_parameters + ;; CHECK: (type $0 (func (param i32 i32 i32) (result i32))) + + ;; CHECK: (type $1 (func (param i32) (result i32))) + + ;; CHECK: (export "callee" (func $callee)) + (export "callee" (func $callee)) + + ;; CHECK: (func $callee (type $0) (param $0 i32) (param $1 i32) (param $2 i32) (result i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (local.get $2) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $callee (param i32 i32 i32) (result i32) + (i32.add + (local.get 0) + (local.get 2) + ) + ) + + ;; CHECK: (func 
$caller (type $1) (param $0 i32) (result i32) + ;; CHECK-NEXT: (local $1 i32) + ;; CHECK-NEXT: (local $2 i32) + ;; CHECK-NEXT: (local $3 i32) + ;; CHECK-NEXT: (block $__inlined_func$callee (result i32) + ;; CHECK-NEXT: (local.set $1 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $2 + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (local.set $3 + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: (local.get $3) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $caller (param i32) (result i32) + (call $callee (local.get 0) (i32.const 1) (local.get 0)) + ) +) + + +(module $wrong_order + ;; CHECK: (type $0 (func (param i32 i32) (result i32))) + + ;; CHECK: (type $1 (func)) + + ;; CHECK: (export "callee1" (func $callee1)) + (export "callee1" (func $callee1)) + + ;; CHECK: (export "callee2" (func $callee2)) + + ;; CHECK: (func $callee1 (type $0) (param $0 i32) (param $1 i32) (result i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $callee1 (param i32 i32) (result i32) + (i32.add (local.get 1) (local.get 0)) + ) + (export "callee2" (func $callee2)) + + ;; CHECK: (func $callee2 (type $0) (param $0 i32) (param $1 i32) (result i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $callee2 (param i32 i32) (result i32) + (i32.add (local.get 0) (local.get 0)) + ) + + ;; CHECK: (func $caller (type $1) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (call $callee1 + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (call $callee2 + ;; CHECK-NEXT: (i32.const 0) + ;; CHECK-NEXT: (i32.const 1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $caller + (drop (call $callee1 
(i32.const 0) (i32.const 1))) + (drop (call $callee2 (i32.const 0) (i32.const 1))) + ) +) + +(module $non_parameters + ;; CHECK: (type $0 (func (result i32))) + + ;; CHECK: (type $1 (func)) + + ;; CHECK: (export "callee1" (func $callee1)) + (export "callee1" (func $callee1)) + + ;; CHECK: (func $callee1 (type $0) (result i32) + ;; CHECK-NEXT: (local $0 i32) + ;; CHECK-NEXT: (local $1 i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $callee1 (result i32) + (local i32 i32) + (i32.add (local.get 0) (local.get 1)) + ) + + ;; CHECK: (func $caller (type $1) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (call $callee1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $caller + (drop (call $callee1)) + ) +) + +(module $non_parameters + ;; CHECK: (type $0 (func (result i32))) + + ;; CHECK: (type $1 (func)) + + ;; CHECK: (export "callee1" (func $callee1)) + (export "callee1" (func $callee1)) + + ;; CHECK: (func $callee1 (type $0) (result i32) + ;; CHECK-NEXT: (local $0 i32) + ;; CHECK-NEXT: (local $1 i32) + ;; CHECK-NEXT: (i32.add + ;; CHECK-NEXT: (local.get $0) + ;; CHECK-NEXT: (local.get $1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $callee1 (result i32) + (local i32 i32) + (i32.add (local.get 0) (local.get 1)) + ) + + ;; CHECK: (func $caller (type $1) + ;; CHECK-NEXT: (drop + ;; CHECK-NEXT: (call $callee1) + ;; CHECK-NEXT: ) + ;; CHECK-NEXT: ) + (func $caller + (drop (call $callee1)) + ) +) diff --git a/test/lit/passes/no-inline.wast b/test/lit/passes/no-inline.wast index 2e21105cb41..2512c0bc831 100644 --- a/test/lit/passes/no-inline.wast +++ b/test/lit/passes/no-inline.wast @@ -509,7 +509,9 @@ ;; NO_PART: (func $maybe-partial-or-full-1 (param $x i32) ;; NO_PART-NEXT: (if - ;; NO_PART-NEXT: (local.get $x) + ;; NO_PART-NEXT: (i32.eqz + ;; NO_PART-NEXT: (local.get $x) + ;; NO_PART-NEXT: ) ;; NO_PART-NEXT: (then ;; NO_PART-NEXT: (call $import) ;; NO_PART-NEXT: ) @@ -517,7 +519,9 @@ ;; 
NO_PART-NEXT: ) ;; NO_BOTH: (func $maybe-partial-or-full-1 (param $x i32) ;; NO_BOTH-NEXT: (if - ;; NO_BOTH-NEXT: (local.get $x) + ;; NO_BOTH-NEXT: (i32.eqz + ;; NO_BOTH-NEXT: (local.get $x) + ;; NO_BOTH-NEXT: ) ;; NO_BOTH-NEXT: (then ;; NO_BOTH-NEXT: (call $import) ;; NO_BOTH-NEXT: ) @@ -529,7 +533,7 @@ ;; inlining is disabled but partial inlining is enabled, we should only ;; partially inline it. (if - (local.get $x) + (i32.eqz (local.get $x)) (then (call $import) ) @@ -538,7 +542,9 @@ ;; NO_PART: (func $maybe-partial-or-full-2 (param $x i32) ;; NO_PART-NEXT: (if - ;; NO_PART-NEXT: (local.get $x) + ;; NO_PART-NEXT: (i32.eqz + ;; NO_PART-NEXT: (local.get $x) + ;; NO_PART-NEXT: ) ;; NO_PART-NEXT: (then ;; NO_PART-NEXT: (return) ;; NO_PART-NEXT: ) @@ -571,7 +577,9 @@ ;; NO_PART-NEXT: ) ;; NO_BOTH: (func $maybe-partial-or-full-2 (param $x i32) ;; NO_BOTH-NEXT: (if - ;; NO_BOTH-NEXT: (local.get $x) + ;; NO_BOTH-NEXT: (i32.eqz + ;; NO_BOTH-NEXT: (local.get $x) + ;; NO_BOTH-NEXT: ) ;; NO_BOTH-NEXT: (then ;; NO_BOTH-NEXT: (return) ;; NO_BOTH-NEXT: ) @@ -606,7 +614,7 @@ ;; As above, but for another form of partial inlining. Here we need to add ;; some extra things to the function size for partial inlining to kick in. 
(if - (local.get $x) + (i32.eqz (local.get $x)) (then (return) ) @@ -648,7 +656,9 @@ ;; YES_ALL-NEXT: (i32.const 0) ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (if - ;; YES_ALL-NEXT: (local.get $0) + ;; YES_ALL-NEXT: (i32.eqz + ;; YES_ALL-NEXT: (local.get $0) + ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (then ;; YES_ALL-NEXT: (call $import) ;; YES_ALL-NEXT: ) @@ -659,7 +669,9 @@ ;; YES_ALL-NEXT: (i32.const 1) ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (if - ;; YES_ALL-NEXT: (local.get $1) + ;; YES_ALL-NEXT: (i32.eqz + ;; YES_ALL-NEXT: (local.get $1) + ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (then ;; YES_ALL-NEXT: (call $import) ;; YES_ALL-NEXT: ) @@ -671,7 +683,9 @@ ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (block ;; YES_ALL-NEXT: (if - ;; YES_ALL-NEXT: (local.get $2) + ;; YES_ALL-NEXT: (i32.eqz + ;; YES_ALL-NEXT: (local.get $2) + ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (then ;; YES_ALL-NEXT: (br $__inlined_func$maybe-partial-or-full-2$2) ;; YES_ALL-NEXT: ) @@ -709,7 +723,9 @@ ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (block ;; YES_ALL-NEXT: (if - ;; YES_ALL-NEXT: (local.get $3) + ;; YES_ALL-NEXT: (i32.eqz + ;; YES_ALL-NEXT: (local.get $3) + ;; YES_ALL-NEXT: ) ;; YES_ALL-NEXT: (then ;; YES_ALL-NEXT: (br $__inlined_func$maybe-partial-or-full-2$3) ;; YES_ALL-NEXT: ) @@ -766,7 +782,9 @@ ;; NO_FULL-NEXT: (i32.const 0) ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (if - ;; NO_FULL-NEXT: (local.get $0) + ;; NO_FULL-NEXT: (i32.eqz + ;; NO_FULL-NEXT: (local.get $0) + ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (then ;; NO_FULL-NEXT: (call $byn-split-outlined-B$maybe-partial-or-full-1 ;; NO_FULL-NEXT: (local.get $0) @@ -779,7 +797,9 @@ ;; NO_FULL-NEXT: (i32.const 1) ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (if - ;; NO_FULL-NEXT: (local.get $1) + ;; NO_FULL-NEXT: (i32.eqz + ;; NO_FULL-NEXT: (local.get $1) + ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (then ;; NO_FULL-NEXT: (call $byn-split-outlined-B$maybe-partial-or-full-1 ;; NO_FULL-NEXT: (local.get $1) @@ -793,7 +813,9 @@ ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (if ;; NO_FULL-NEXT: (i32.eqz - ;; 
NO_FULL-NEXT: (local.get $2) + ;; NO_FULL-NEXT: (i32.eqz + ;; NO_FULL-NEXT: (local.get $2) + ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (then ;; NO_FULL-NEXT: (call $byn-split-outlined-A$maybe-partial-or-full-2 @@ -808,7 +830,9 @@ ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (if ;; NO_FULL-NEXT: (i32.eqz - ;; NO_FULL-NEXT: (local.get $3) + ;; NO_FULL-NEXT: (i32.eqz + ;; NO_FULL-NEXT: (local.get $3) + ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: ) ;; NO_FULL-NEXT: (then ;; NO_FULL-NEXT: (call $byn-split-outlined-A$maybe-partial-or-full-2