From a45a6e5b414f6cb838005d35d093fbcdc2e72d41 Mon Sep 17 00:00:00 2001 From: AuroraV Date: Sun, 2 Jul 2023 12:37:35 +0800 Subject: [PATCH] feat:customFuncs add ctx support --- cli/cli.go | 9 ++++--- compiler.go | 4 +-- execute.go | 2 +- func.go | 23 +++++++++-------- option.go | 21 +++++++++------- option_function_test.go | 3 ++- option_iter_function_test.go | 3 ++- option_test.go | 49 ++++++++++++++++++------------------ 8 files changed, 61 insertions(+), 53 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index a10ead55..83d08d52 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -2,6 +2,7 @@ package cli import ( + "context" "errors" "fmt" "io" @@ -242,8 +243,8 @@ Usage: gojq.WithFunction("debug", 0, 0, cli.funcDebug), gojq.WithFunction("stderr", 0, 0, cli.funcStderr), gojq.WithFunction("input_filename", 0, 0, - func(iter inputIter) func(any, []any) any { - return func(any, []any) any { + func(iter inputIter) func(context.Context, any, []any) any { + return func(context.Context, any, []any) any { if fname := iter.Name(); fname != "" && (len(args) > 0 || !opts.InputNull) { return fname } @@ -408,7 +409,7 @@ func (cli *cli) createMarshaler() marshaler { return f } -func (cli *cli) funcDebug(v any, _ []any) any { +func (cli *cli) funcDebug(_ context.Context, v any, _ []any) any { if err := newEncoder(false, 0).marshal([]any{"DEBUG:", v}, cli.errStream); err != nil { return err } @@ -418,7 +419,7 @@ func (cli *cli) funcDebug(v any, _ []any) any { return v } -func (cli *cli) funcStderr(v any, _ []any) any { +func (cli *cli) funcStderr(_ context.Context, v any, _ []any) any { if err := newEncoder(false, 0).marshal(v, cli.errStream); err != nil { return err } diff --git a/compiler.go b/compiler.go index de5f9a10..56801d45 100644 --- a/compiler.go +++ b/compiler.go @@ -804,8 +804,8 @@ func (c *compiler) compileBreak(label string) error { return nil } -func funcBreak(label string) func(any, []any) any { - return func(v any, _ []any) any { +func funcBreak(label string) func(context.Context, any, []any) any { + return func(_ context.Context, v any, _ []any) any { return &breakError{label, v} } } diff --git a/execute.go b/execute.go index dcf9d984..318537e4 100644 --- a/execute.go +++ b/execute.go @@ -189,7 +189,7 @@ loop: for i := 0; i < argcnt; i++ { args[i] = env.pop() } - w := v[0].(func(any, []any) any)(x, args) + w := v[0].(func(context.Context, any, []any) any)(env.ctx, x, args) if e, ok := w.(error); ok { if er, ok := e.(*exitCodeError); !ok || er.value != nil || er.halt { err = e diff --git a/func.go b/func.go index 6e8d1500..607e0897 100644 --- a/func.go +++ b/func.go @@ -1,6 +1,7 @@ package gojq import ( + "context" "encoding/base64" "encoding/json" "errors" @@ -33,7 +34,7 @@ const ( type function struct { argcount int iter bool - callback func(any, []any) any + callback func(context.Context, any, []any) any } func (fn function) accept(cnt int) bool { @@ -202,7 +203,7 @@ func init() { func argFunc0(f func(any) any) function { return function{ - argcount0, false, func(v any, _ []any) any { + argcount0, false, func(_ context.Context, v any, _ []any) any { return f(v) }, } @@ -210,7 +211,7 @@ func argFunc0(f func(any) any) function { func argFunc1(f func(_, _ any) any) function { return function{ - argcount1, false, func(v any, args []any) any { + argcount1, false, func(_ context.Context, v any, args []any) any { return f(v, args[0]) }, } @@ -218,7 +219,7 @@ func argFunc1(f func(_, _ any) any) function { func argFunc2(f func(_, _, _ any) any) function { return function{ - argcount2, false, func(v any, args []any) any { + argcount2, false, func(_ context.Context, v any, args []any) any { return f(v, args[0], args[1]) }, } @@ -226,7 +227,7 @@ func argFunc2(f func(_, _, _ any) any) function { func argFunc3(f func(_, _, _, _ any) any) function { return function{ - argcount3, false, func(v any, args []any) any { + argcount3, false, func(_ context.Context, v any, args []any) any { return f(v, args[0], args[1], args[2]) }, } @@ -718,7 +719,7 @@ func funcImplode(v any) any { return sb.String() } -func funcSplit(v any, args []any) any { +func funcSplit(_ context.Context, v any, args []any) any { s, ok := v.(string) if !ok { return &func0TypeError{"split", v} @@ -809,7 +810,7 @@ func funcFormat(v, x any) any { if f == nil { return &formatNotFoundError{format} } - return internalFuncs[f.Name].callback(v, nil) + return internalFuncs[f.Name].callback(context.Background(), v, nil) } var htmlEscaper = strings.NewReplacer( @@ -1101,7 +1102,7 @@ func clampIndex(i, min, max int) int { } } -func funcFlatten(v any, args []any) any { +func funcFlatten(_ context.Context, v any, args []any) any { vs, ok := values(v) if !ok { return &func0TypeError{"flatten", v} @@ -1145,7 +1146,7 @@ func (iter *rangeIter) Next() (any, bool) { return v, true } -func funcRange(_ any, xs []any) any { +func funcRange(_ context.Context, _ any, xs []any) any { for _, x := range xs { switch x.(type) { case int, float64, *big.Int: @@ -2048,7 +2049,7 @@ func funcCapture(v any) any { return w } -func funcError(v any, args []any) any { +func funcError(_ context.Context, v any, args []any) any { if len(args) > 0 { v = args[0] } @@ -2063,7 +2064,7 @@ func funcHalt(any) any { return &exitCodeError{nil, 0, true} } -func funcHaltError(v any, args []any) any { +func funcHaltError(_ context.Context, v any, args []any) any { code := 5 if len(args) > 0 { var ok bool diff --git a/option.go b/option.go index f1a110fa..07d14cc6 100644 --- a/option.go +++ b/option.go @@ -1,6 +1,9 @@ package gojq -import "fmt" +import ( + "context" + "fmt" +) // CompilerOption is a compiler option. type CompilerOption func(*compiler) @@ -39,7 +42,7 @@ func WithVariables(variables []string) CompilerOption { // function. If you want to emit multiple values, call the empty function, // accept a filter for its argument, or call another built-in function, then // use LoadInitModules of the module loader. -func WithFunction(name string, minarity, maxarity int, f func(any, []any) any) CompilerOption { +func WithFunction(name string, minarity, maxarity int, f func(context.Context, any, []any) any) CompilerOption { return withFunction(name, minarity, maxarity, false, f) } @@ -48,15 +51,15 @@ func WithFunction(name string, minarity, maxarity int, f func(any, []any) any) C // returns an Iter to emit multiple values. You cannot define both iterator and // non-iterator functions of the same name (with possibly different arities). // See also [NewIter], which can be used to convert values or an error to an Iter. -func WithIterFunction(name string, minarity, maxarity int, f func(any, []any) Iter) CompilerOption { +func WithIterFunction(name string, minarity, maxarity int, f func(context.Context, any, []any) Iter) CompilerOption { return withFunction(name, minarity, maxarity, true, - func(v any, args []any) any { - return f(v, args) + func(ctx context.Context, v any, args []any) any { + return f(ctx, v, args) }, ) } -func withFunction(name string, minarity, maxarity int, iter bool, f func(any, []any) any) CompilerOption { +func withFunction(name string, minarity, maxarity int, iter bool, f func(context.Context, any, []any) any) CompilerOption { if !(0 <= minarity && minarity <= maxarity && maxarity <= 30) { panic(fmt.Sprintf("invalid arity for %q: %d, %d", name, minarity, maxarity)) } @@ -71,11 +74,11 @@ func withFunction(name string, minarity, maxarity int, iter bool, f func(any, [] } c.customFuncs[name] = function{ argcount | fn.argcount, iter, - func(x any, xs []any) any { + func(ctx context.Context, x any, xs []any) any { if argcount&(1<