diff --git a/code.go b/code.go index e44a78f3..f1935d84 100644 --- a/code.go +++ b/code.go @@ -26,6 +26,7 @@ const ( opjump opjumpifnot opcall + opcallrec oppushpc opcallpc opscope @@ -75,6 +76,8 @@ func (op opcode) String() string { return "jumpifnot" case opcall: return "call" + case opcallrec: + return "callrec" case oppushpc: return "pushpc" case opcallpc: diff --git a/compiler.go b/compiler.go index e09e3c02..98fd6b1c 100644 --- a/compiler.go +++ b/compiler.go @@ -96,7 +96,7 @@ func Compile(q *Query, options ...CompilerOption) (*Code, error) { scope := c.newScope() c.scopes = []*scopeinfo{scope} defer c.lazy(func() *code { - return &code{op: opscope, v: [2]int{scope.id, scope.variablecnt}} + return &code{op: opscope, v: [3]int{scope.id, scope.variablecnt, 0}} })() if c.moduleLoader != nil { if moduleLoader, ok := c.moduleLoader.(interface { @@ -319,7 +319,7 @@ func (c *compiler) compileFuncDef(e *FuncDef, builtin bool) error { scope = c.newScope() c.scopes = append(c.scopes, scope) defer c.lazy(func() *code { - return &code{op: opscope, v: [2]int{scope.id, scope.variablecnt}} + return &code{op: opscope, v: [3]int{scope.id, scope.variablecnt, len(e.Args)}} })() if len(e.Args) > 0 { v := c.newVariable() @@ -1439,28 +1439,42 @@ func (c *compiler) lazy(f func() *code) func() { func (c *compiler) optimizeTailRec() { var pcs []int - targets := map[int]struct{}{} + scopes := map[int]struct{}{} + forked := map[int]struct{}{} L: for i, l := 0, len(c.codes); i < l; i++ { switch c.codes[i].op { case opscope: pcs = append(pcs, i) - if c.codes[i].v.([2]int)[1] == 0 { - targets[i] = struct{}{} + if c.codes[i].v.([3]int)[2] == 0 { + scopes[i] = struct{}{} } + case opfork, opforktrybegin, opforkalt: + forked[c.codes[i].v.(int)] = struct{}{} case opcall: if j, ok := c.codes[i].v.(int); !ok || len(pcs) == 0 || pcs[len(pcs)-1] != j { break - } else if _, ok := targets[j]; !ok { + } else if _, ok = scopes[j]; !ok { break } + canjump := true for j := i + 1; j < l; { switch c.codes[j].op { case opjump: j = c.codes[j].v.(int) + if canjump { + if _, ok := forked[j+1]; ok { + canjump = false + } + } case opret: - c.codes[i] = &code{op: opjump, v: pcs[len(pcs)-1] + 1} + if canjump { + c.codes[i].op = opjump + c.codes[i].v = pcs[len(pcs)-1] + 1 + } else { + c.codes[i].op = opcallrec + } continue L default: continue L diff --git a/compiler_test.go b/compiler_test.go index 870f0b26..ca3ebea8 100644 --- a/compiler_test.go +++ b/compiler_test.go @@ -130,7 +130,7 @@ func TestCodeCompile_OptimizeConstants(t *testing.T) { } } -func TestCodeCompile_OptimizeTailRec(t *testing.T) { +func TestCodeCompile_OptimizeTailRec_Range(t *testing.T) { query, err := gojq.Parse("range(10)") if err != nil { t.Fatal(err) @@ -166,6 +166,48 @@ func TestCodeCompile_OptimizeTailRec(t *testing.T) { } } +func TestCodeCompile_OptimizeTailRec_ScopeVar(t *testing.T) { + query, err := gojq.Parse("def f: . as $x | $x, (if $x < 3 then $x + 1 | f else empty end); f") + if err != nil { + t.Fatal(err) + } + code, err := gojq.Compile(query) + if err != nil { + t.Fatal(err) + } + codes := reflect.ValueOf(code).Elem().FieldByName("codes") + if got, expected := codes.Len(), 43; expected != got { + t.Errorf("expected: %v, got: %v", expected, got) + } + op1 := codes.Index(37).Elem().FieldByName("op") // call f by jump + op2 := codes.Index(38).Elem().FieldByName("op") + if got, expected := *(*int)(unsafe.Pointer(op2.UnsafeAddr())), + *(*int)(unsafe.Pointer(op1.UnsafeAddr())); expected != got { + t.Errorf("expected: %v, got: %v", expected, got) + } +} + +func TestCodeCompile_OptimizeTailRec_CallRec(t *testing.T) { + query, err := gojq.Parse("def f: . as $x | $x, (if $x < 3 then $x + 1 | f else empty end), $x; f") + if err != nil { + t.Fatal(err) + } + code, err := gojq.Compile(query) + if err != nil { + t.Fatal(err) + } + codes := reflect.ValueOf(code).Elem().FieldByName("codes") + if got, expected := codes.Len(), 47; expected != got { + t.Errorf("expected: %v, got: %v", expected, got) + } + op1 := codes.Index(38).Elem().FieldByName("op") // callrec f + op2 := codes.Index(37).Elem().FieldByName("op") // call _add/2 + if got, expected := *(*int)(unsafe.Pointer(op2.UnsafeAddr()))+1, + *(*int)(unsafe.Pointer(op1.UnsafeAddr())); expected != got { + t.Errorf("expected: %v, got: %v", expected, got) + } +} + func TestCodeCompile_OptimizeJumps(t *testing.T) { query, err := gojq.Parse("def f: 1; def g: 2; def h: 3; f") if err != nil { diff --git a/debug.go b/debug.go index 305a00ab..66635339 100644 --- a/debug.go +++ b/debug.go @@ -75,7 +75,7 @@ func (env *env) debugCodes() { for i, c := range env.codes { pc := i switch c.op { - case opcall: + case opcall, opcallrec: if x, ok := c.v.(int); ok { pc = x } @@ -87,9 +87,14 @@ func (env *env) debugCodes() { } var s string if name := env.lookupInfoName(pc); name != "" { - if (c.op == opcall || c.op == opjump) && !strings.HasPrefix(name, "module ") { - s = "\t## call " + name - } else { + switch c.op { + case opcall, opcallrec, opjump: + if !strings.HasPrefix(name, "module ") { + s = "\t## call " + name + break + } + fallthrough + default: s = "\t## " + name } } @@ -114,7 +119,7 @@ func (env *env) debugState(pc int, backtrack bool) { sb.WriteString(debugJSON(env.stack.data[xs[i]].value)) } switch c.op { - case opcall: + case opcall, opcallrec: if x, ok := c.v.(int); ok { pc = x } @@ -125,9 +130,14 @@ func (env *env) debugState(pc int, backtrack bool) { } } if name := env.lookupInfoName(pc); name != "" { - if (c.op == opcall || c.op == opjump) && !strings.HasPrefix(name, "module ") { - sb.WriteString("\t\t\t## call " + name) - } else { + switch c.op { + case opcall, opcallrec, opjump: + if !strings.HasPrefix(name, "module ") { + sb.WriteString("\t\t\t## call " + name) + break + } + fallthrough + default: sb.WriteString("\t\t\t## " + name) } } @@ -162,7 +172,8 @@ func (env *env) debugForks(pc int, op string) { } func debugOperand(c *code) string { - if c.op == opcall { + switch c.op { + case opcall, opcallrec: switch v := c.v.(type) { case int: return debugJSON(v) @@ -171,7 +182,7 @@ func debugOperand(c *code) string { default: panic(c) } - } else { + default: return debugJSON(c.v) } } diff --git a/execute.go b/execute.go index 32b395dc..d6517fe2 100644 --- a/execute.go +++ b/execute.go @@ -187,6 +187,9 @@ loop: default: panic(v) } + case opcallrec: + pc, callpc, index = code.v.(int), -pc, env.scopes.index + goto loop case oppushpc: env.push([2]int{code.v.(int), env.scopes.index}) case opcallpc: @@ -194,10 +197,15 @@ loop: pc, callpc, index = xs[0], pc, xs[1] goto loop case opscope: - xs := code.v.([2]int) + xs := code.v.([3]int) var i, l int if index == env.scopes.index { - i = index + if callpc >= 0 { + i = index + } else { + callpc = -callpc + i = env.scopes.top().(scope).saveindex + } } else { env.scopes.save(&i, &l) env.scopes.index = index