Skip to content

Commit

Permalink
improve tail call optimization (close #86)
Browse files Browse the repository at this point in the history
Previously, only functions with no local variables, like _while and _until,
were optimized, using the opjump instruction. This commit also optimizes
functions that have local variables but no arguments, provided there is no
following fork point, because local variables may be referenced after
backtracking. For tail calls with following fork points, it now uses the new
opcallrec instruction, which optimizes the return address but still allocates
space for local variables.
  • Loading branch information
itchyny committed Jul 22, 2021
1 parent d47bd6b commit 97658f3
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 20 deletions.
3 changes: 3 additions & 0 deletions code.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
opjump
opjumpifnot
opcall
opcallrec
oppushpc
opcallpc
opscope
Expand Down Expand Up @@ -75,6 +76,8 @@ func (op opcode) String() string {
return "jumpifnot"
case opcall:
return "call"
case opcallrec:
return "callrec"
case oppushpc:
return "pushpc"
case opcallpc:
Expand Down
28 changes: 21 additions & 7 deletions compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func Compile(q *Query, options ...CompilerOption) (*Code, error) {
scope := c.newScope()
c.scopes = []*scopeinfo{scope}
defer c.lazy(func() *code {
return &code{op: opscope, v: [2]int{scope.id, scope.variablecnt}}
return &code{op: opscope, v: [3]int{scope.id, scope.variablecnt, 0}}
})()
if c.moduleLoader != nil {
if moduleLoader, ok := c.moduleLoader.(interface {
Expand Down Expand Up @@ -319,7 +319,7 @@ func (c *compiler) compileFuncDef(e *FuncDef, builtin bool) error {
scope = c.newScope()
c.scopes = append(c.scopes, scope)
defer c.lazy(func() *code {
return &code{op: opscope, v: [2]int{scope.id, scope.variablecnt}}
return &code{op: opscope, v: [3]int{scope.id, scope.variablecnt, len(e.Args)}}
})()
if len(e.Args) > 0 {
v := c.newVariable()
Expand Down Expand Up @@ -1439,28 +1439,42 @@ func (c *compiler) lazy(f func() *code) func() {

func (c *compiler) optimizeTailRec() {
var pcs []int
targets := map[int]struct{}{}
scopes := map[int]struct{}{}
forked := map[int]struct{}{}
L:
for i, l := 0, len(c.codes); i < l; i++ {
switch c.codes[i].op {
case opscope:
pcs = append(pcs, i)
if c.codes[i].v.([2]int)[1] == 0 {
targets[i] = struct{}{}
if c.codes[i].v.([3]int)[2] == 0 {
scopes[i] = struct{}{}
}
case opfork, opforktrybegin, opforkalt:
forked[c.codes[i].v.(int)] = struct{}{}
case opcall:
if j, ok := c.codes[i].v.(int); !ok ||
len(pcs) == 0 || pcs[len(pcs)-1] != j {
break
} else if _, ok := targets[j]; !ok {
} else if _, ok = scopes[j]; !ok {
break
}
canjump := true
for j := i + 1; j < l; {
switch c.codes[j].op {
case opjump:
j = c.codes[j].v.(int)
if canjump {
if _, ok := forked[j+1]; ok {
canjump = false
}
}
case opret:
c.codes[i] = &code{op: opjump, v: pcs[len(pcs)-1] + 1}
if canjump {
c.codes[i].op = opjump
c.codes[i].v = pcs[len(pcs)-1] + 1
} else {
c.codes[i].op = opcallrec
}
continue L
default:
continue L
Expand Down
44 changes: 43 additions & 1 deletion compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestCodeCompile_OptimizeConstants(t *testing.T) {
}
}

func TestCodeCompile_OptimizeTailRec(t *testing.T) {
func TestCodeCompile_OptimizeTailRec_Range(t *testing.T) {
query, err := gojq.Parse("range(10)")
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -166,6 +166,48 @@ func TestCodeCompile_OptimizeTailRec(t *testing.T) {
}
}

// TestCodeCompile_OptimizeTailRec_ScopeVar compiles a recursive function f
// that binds a local variable ($x) but takes no arguments, with the recursive
// call to f inside the last expression of its body, and checks the opcodes
// emitted around that call.
func TestCodeCompile_OptimizeTailRec_ScopeVar(t *testing.T) {
query, err := gojq.Parse("def f: . as $x | $x, (if $x < 3 then $x + 1 | f else empty end); f")
if err != nil {
t.Fatal(err)
}
code, err := gojq.Compile(query)
if err != nil {
t.Fatal(err)
}
// The instruction list is an unexported field of the compiled code, so it
// is reached through reflection.
codes := reflect.ValueOf(code).Elem().FieldByName("codes")
// Pin the total instruction count; the fixed indices below only make sense
// for this exact layout.
if got, expected := codes.Len(), 43; expected != got {
t.Errorf("expected: %v, got: %v", expected, got)
}
op1 := codes.Index(37).Elem().FieldByName("op") // call f by jump
// NOTE(review): index 38 is presumably an ordinary jump emitted for the
// if/else, used here as the reference opcode — confirm against a code dump.
op2 := codes.Index(38).Elem().FieldByName("op")
// The op field is unexported, so its value is read through its address.
// The check passes only when both instructions carry the same opcode.
// NOTE(review): the *int cast assumes opcode's underlying type is int — confirm.
if got, expected := *(*int)(unsafe.Pointer(op2.UnsafeAddr())),
*(*int)(unsafe.Pointer(op1.UnsafeAddr())); expected != got {
t.Errorf("expected: %v, got: %v", expected, got)
}
}

// TestCodeCompile_OptimizeTailRec_CallRec compiles a variant of the recursive
// function f in which another alternative (", $x") follows the expression
// containing the recursive call, and checks that the call to f is emitted
// with the opcode that directly follows the plain call opcode.
func TestCodeCompile_OptimizeTailRec_CallRec(t *testing.T) {
query, err := gojq.Parse("def f: . as $x | $x, (if $x < 3 then $x + 1 | f else empty end), $x; f")
if err != nil {
t.Fatal(err)
}
code, err := gojq.Compile(query)
if err != nil {
t.Fatal(err)
}
// The instruction list is an unexported field of the compiled code, so it
// is reached through reflection.
codes := reflect.ValueOf(code).Elem().FieldByName("codes")
// Pin the total instruction count; the fixed indices below only make sense
// for this exact layout.
if got, expected := codes.Len(), 47; expected != got {
t.Errorf("expected: %v, got: %v", expected, got)
}
op1 := codes.Index(38).Elem().FieldByName("op") // callrec f
op2 := codes.Index(37).Elem().FieldByName("op") // call _add/2
// The opcode at 38 must equal the opcode at 37 plus one. opcallrec is
// declared immediately after opcall in the opcode const block, so this
// holds when 37 is opcall and 38 is opcallrec.
// NOTE(review): assumes the opcodes are numbered consecutively (iota) and
// that opcode's underlying type is int — confirm.
if got, expected := *(*int)(unsafe.Pointer(op2.UnsafeAddr()))+1,
*(*int)(unsafe.Pointer(op1.UnsafeAddr())); expected != got {
t.Errorf("expected: %v, got: %v", expected, got)
}
}

func TestCodeCompile_OptimizeJumps(t *testing.T) {
query, err := gojq.Parse("def f: 1; def g: 2; def h: 3; f")
if err != nil {
Expand Down
31 changes: 21 additions & 10 deletions debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (env *env) debugCodes() {
for i, c := range env.codes {
pc := i
switch c.op {
case opcall:
case opcall, opcallrec:
if x, ok := c.v.(int); ok {
pc = x
}
Expand All @@ -87,9 +87,14 @@ func (env *env) debugCodes() {
}
var s string
if name := env.lookupInfoName(pc); name != "" {
if (c.op == opcall || c.op == opjump) && !strings.HasPrefix(name, "module ") {
s = "\t## call " + name
} else {
switch c.op {
case opcall, opcallrec, opjump:
if !strings.HasPrefix(name, "module ") {
s = "\t## call " + name
break
}
fallthrough
default:
s = "\t## " + name
}
}
Expand All @@ -114,7 +119,7 @@ func (env *env) debugState(pc int, backtrack bool) {
sb.WriteString(debugJSON(env.stack.data[xs[i]].value))
}
switch c.op {
case opcall:
case opcall, opcallrec:
if x, ok := c.v.(int); ok {
pc = x
}
Expand All @@ -125,9 +130,14 @@ func (env *env) debugState(pc int, backtrack bool) {
}
}
if name := env.lookupInfoName(pc); name != "" {
if (c.op == opcall || c.op == opjump) && !strings.HasPrefix(name, "module ") {
sb.WriteString("\t\t\t## call " + name)
} else {
switch c.op {
case opcall, opcallrec, opjump:
if !strings.HasPrefix(name, "module ") {
sb.WriteString("\t\t\t## call " + name)
break
}
fallthrough
default:
sb.WriteString("\t\t\t## " + name)
}
}
Expand Down Expand Up @@ -162,7 +172,8 @@ func (env *env) debugForks(pc int, op string) {
}

func debugOperand(c *code) string {
if c.op == opcall {
switch c.op {
case opcall, opcallrec:
switch v := c.v.(type) {
case int:
return debugJSON(v)
Expand All @@ -171,7 +182,7 @@ func debugOperand(c *code) string {
default:
panic(c)
}
} else {
default:
return debugJSON(c.v)
}
}
Expand Down
12 changes: 10 additions & 2 deletions execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,25 @@ loop:
default:
panic(v)
}
case opcallrec:
pc, callpc, index = code.v.(int), -pc, env.scopes.index
goto loop
case oppushpc:
env.push([2]int{code.v.(int), env.scopes.index})
case opcallpc:
xs := env.pop().([2]int)
pc, callpc, index = xs[0], pc, xs[1]
goto loop
case opscope:
xs := code.v.([2]int)
xs := code.v.([3]int)
var i, l int
if index == env.scopes.index {
i = index
if callpc >= 0 {
i = index
} else {
callpc = -callpc
i = env.scopes.top().(scope).saveindex
}
} else {
env.scopes.save(&i, &l)
env.scopes.index = index
Expand Down

0 comments on commit 97658f3

Please sign in to comment.