diff --git a/execution/engine/execution_engine_test.go b/execution/engine/execution_engine_test.go index c3ed66fd6..285edc553 100644 --- a/execution/engine/execution_engine_test.go +++ b/execution/engine/execution_engine_test.go @@ -749,7 +749,7 @@ func TestExecutionEngine_Execute(t *testing.T) { }, }, }, - expectedResponse: `{"errors":[{"message":"Failed to fetch from Subgraph at Path 'query', Reason: no data or errors in response."},{"message":"Cannot return null for non-nullable field 'Query.hero'.","path":["hero"]}],"data":null}`, + expectedResponse: `{"errors":[{"message":"Failed to fetch from Subgraph at Path 'query', Reason: invalid JSON."}],"data":null}`, })) t.Run("execute operation and apply input coercion for lists without variables", runWithoutError(ExecutionEngineTestCase{ diff --git a/go.work.sum b/go.work.sum index 7d3dfc865..895794308 100644 --- a/go.work.sum +++ b/go.work.sum @@ -26,6 +26,8 @@ github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer5 github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= github.com/antihax/optional v1.0.0 h1:xK2lYat7ZLaVVcIuj82J8kIro4V6kDe0AUDFboUCwcg= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/barkyq/fastjson v0.0.0-20230118153732-bb1076612fd9 h1:gsI0tqI5IvQ3xIH4oXttNu0EMcBEUR+1RmfgayyGjVE= +github.com/barkyq/fastjson v0.0.0-20230118153732-bb1076612fd9/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= diff --git a/v2/go.mod b/v2/go.mod index 7aa0012a9..558f02986 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -8,6 +8,7 @@ require ( github.com/buger/jsonparser v1.1.1 github.com/cespare/xxhash/v2 v2.2.0 github.com/davecgh/go-spew v1.1.1 + github.com/goccy/go-json v0.10.2 github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.6.0 github.com/google/uuid v1.6.0 @@ -24,6 +25,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 + github.com/valyala/fastjson v1.6.4 github.com/vektah/gqlparser/v2 v2.5.11 go.uber.org/atomic v1.11.0 go.uber.org/zap v1.26.0 @@ -68,3 +70,5 @@ require ( gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/valyala/fastjson v1.6.4 => github.com/barkyq/fastjson v0.0.0-20230118153732-bb1076612fd9 diff --git a/v2/go.sum b/v2/go.sum index 323e0c009..2a61ed0d0 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -9,6 +9,8 @@ github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNg github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= +github.com/barkyq/fastjson v0.0.0-20230118153732-bb1076612fd9 h1:gsI0tqI5IvQ3xIH4oXttNu0EMcBEUR+1RmfgayyGjVE= +github.com/barkyq/fastjson v0.0.0-20230118153732-bb1076612fd9/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= diff --git a/v2/pkg/astjson/astjson.go b/v2/pkg/astjson/astjson.go index f3f451da0..1f0f1e98e 100644 --- a/v2/pkg/astjson/astjson.go +++ b/v2/pkg/astjson/astjson.go @@ -269,7 +269,7 @@ func (j *JSON) ParseArray(input []byte) (err error) { func (j *JSON) AppendAnyJSONBytes(input []byte) (ref int, err error) { if j.storage == nil { - j.storage = make([]byte, 0, 4*1024) + j.storage = make([]byte, 0, len(input)) } start := len(j.storage) j.storage = append(j.storage, input...) diff --git a/v2/pkg/astjson/astjson_test.go b/v2/pkg/astjson/astjson_test.go index 27987b9c9..19292e7d6 100644 --- a/v2/pkg/astjson/astjson_test.go +++ b/v2/pkg/astjson/astjson_test.go @@ -6,6 +6,8 @@ import ( "github.com/buger/jsonparser" "github.com/stretchr/testify/assert" + "github.com/valyala/fastjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" ) func TestJSON_ParsePrint(t *testing.T) { @@ -401,6 +403,116 @@ func BenchmarkJSON_ParsePrint(b *testing.B) { } } +func BenchmarkFastJSON(b *testing.B) { + var p fastjson.Parser + input := []byte(`{"data":{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}}`) + expectedOut := []byte(`{"_entities":[{"stock":8},{"stock":2},{"stock":5}]}`) + res := make([]byte, 0, 1024) + b.SetBytes(int64(len(input))) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + v, err := p.ParseBytes(input) + if err != nil { + b.Fatal(err) + } + out := v.Get("data") + res = out.MarshalTo(res) + if !bytes.Equal(expectedOut, res) { + b.Fatal("not equal") + } + res = res[:0] + } +} + +func TestFastJsonMerge(t *testing.T) { + a, err := fastjson.ParseBytes([]byte(`{"a":1,"b":2}`)) + assert.NoError(t, err) + b, err := fastjson.ParseBytes([]byte(`{"c":3}`)) + assert.NoError(t, err) + merged, _ := fastjsonext.MergeValues(a, b) + out := merged.MarshalTo(nil) + assert.Equal(t, `{"a":1,"b":2,"c":3}`, string(out)) +} + +func TestFastJsonMergeNested(t *testing.T) { + a, err := fastjson.ParseBytes([]byte(`{"a":1,"b":2,"c":{"d":4,"e":4}}`)) + assert.NoError(t, err) + b, err := fastjson.ParseBytes([]byte(`{"c":{"e":5}}`)) + assert.NoError(t, err) + merged, _ := fastjsonext.MergeValues(a, b) + out := merged.MarshalTo(nil) + assert.Equal(t, `{"a":1,"b":2,"c":{"d":4,"e":5}}`, string(out)) +} + +func BenchmarkFastParse(b *testing.B) { + var p fastjson.Parser + + b.SetBytes(int64(len(bigJSON))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + v, err := p.ParseBytes(bigJSON) + if err != nil { + b.Fatal(err) + } + if v == nil { + b.Fatal("nil") + } + } +} + +func BenchmarkParse(b *testing.B) { + fs := &JSON{} + b.SetBytes(int64(len(bigJSON))) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fs.Reset() + ref, err := fs.AppendAnyJSONBytes(bigJSON) + if err != nil { + b.Fatal(err) + } + if ref == -1 { + b.Fatal("nil") + } + } +} + +func BenchmarkFastJsonMerge(t *testing.B) { + var ( + p1, p2, p3 fastjson.Parser + out = make([]byte, 0, 1024) + ) + first := []byte(`{"a":1,"b":2,"c":{"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26}}`) + second := []byte(`{"c":{"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26}}`) + third := []byte(`{"c":{"e":6,"f":7,"g":8,"h":9,"i":10,"j":11,"k":true,"l":13,"m":"Cosmo Rocks!","n":15,"o":16,"p":17,"q":18,"r":19,"s":20,"t":21,"u":22,"v":23,"w":24,"x":25,"y":26,"z":28}}`) + expected := []byte(`{"a":1,"b":2,"c":{"d":4,"e":6,"f":7,"g":8,"h":9,"i":10,"j":11,"k":11,"l":13,"m":13,"n":15,"o":16,"p":17,"q":18,"r":19,"s":20,"t":21,"u":22,"v":23,"w":24,"x":25,"y":26,"z":28}}`) + t.SetBytes(int64(len(first) + len(second) + len(third))) + t.ReportAllocs() + t.ResetTimer() + for i := 0; i < t.N; i++ { + a, err := p1.ParseBytes(first) + if err != nil { + t.Fatal(err) + } + b, err := p2.ParseBytes(second) + if err != nil { + t.Fatal(err) + } + c, err := p3.ParseBytes(third) + if err != nil { + t.Fatal(err) + } + ab, _ := fastjsonext.MergeValues(a, b) + abc, _ := fastjsonext.MergeValues(ab, c) + out = abc.MarshalTo(out[:0]) + if !bytes.Equal(expected, out) { + t.Fatal("not equal") + } + + } +} + func BenchmarkJSON_MergeNodesNested(b *testing.B) { js := &JSON{} first := []byte(`{"a":1,"b":2,"c":{"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26}}`) @@ -467,3 +579,7 @@ func BenchmarkJSON_MergeNodesWithPath(b *testing.B) { } } } + +var ( + bigJSON = []byte(`{"data":{"employees":[{"id":1,"details":{"forename":"Jens","surname":"Neuse"}},{"id":2,"details":{"forename":"Dustin","surname":"Deus"}},{"id":3,"details":{"forename":"Stefan","surname":"Avram"}},{"id":4,"details":{"forename":"Björn","surname":"Schwenzer"}},{"id":5,"details":{"forename":"Sergiy","surname":"Petrunin"}},{"id":7,"details":{"forename":"Suvij","surname":"Surya"}},{"id":8,"details":{"forename":"Nithin","surname":"Kumar"}},{"id":10,"details":{"forename":"Eelco","surname":"Wiersma"}},{"id":11,"details":{"forename":"Alexandra","surname":"Neuse"}},{"id":12,"details":{"forename":"David","surname":"Stutt"}}]}}`) +) diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 9a8853589..1ce3f4c8b 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -40,7 +40,7 @@ var ( SelectResponseErrorsPath: []string{"errors"}, } SingleEntityPostProcessingConfiguration = resolve.PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, SelectResponseErrorsPath: []string{"errors"}, } ) diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index 7d5648083..08a722d31 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -20,6 +20,7 @@ import ( "github.com/buger/jsonparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/literal" + "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) const ( @@ -132,7 +133,18 @@ func releaseBuffer(buf *bytes.Buffer) { requestBufferPool.Put(buf) } +type bodyHashContextKey struct{} + +func BodyHashFromContext(ctx context.Context) (uint64, bool) { + value := ctx.Value(bodyHashContextKey{}) + if value == nil { + return 0, false + } + return value.(uint64), true +} + func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out *bytes.Buffer, contentType string) (err error) { + request, err := http.NewRequestWithContext(ctx, string(method), string(url), body) if err != nil { return err @@ -243,7 +255,11 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head func Do(client *http.Client, ctx context.Context, requestInput []byte, out *bytes.Buffer) (err error) { url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) - + h := pool.Hash64.Get() + _, _ = h.Write(body) + bodyHash := h.Sum64() + pool.Hash64.Put(h) + ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out, ContentTypeJSON) } @@ -256,6 +272,10 @@ func DoMultipartForm( url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput) + h := pool.Hash64.Get() + defer pool.Hash64.Put(h) + _, _ = h.Write(body) + formValues := map[string]io.Reader{ "operations": bytes.NewReader(body), } @@ -273,6 +293,7 @@ func DoMultipartForm( fileMap = fmt.Sprintf(`%s, "%d" : ["variables.files.%d"]`, fileMap, i, i) } key := fmt.Sprintf("%d", i) + _, _ = h.WriteString(file.Path()) temporaryFile, err := os.Open(file.Path()) tempFiles = append(tempFiles, temporaryFile) if err != nil { @@ -299,6 +320,9 @@ func DoMultipartForm( } }() + bodyHash := h.Sum64() + ctx = context.WithValue(ctx, bodyHashContextKey{}, bodyHash) + return makeHTTPRequest(client, ctx, url, method, headers, queryParams, multipartBody, enableTrace, out, contentType) } diff --git a/v2/pkg/engine/resolve/authorization_test.go b/v2/pkg/engine/resolve/authorization_test.go index 0d6275668..a35a1315e 100644 --- a/v2/pkg/engine/resolve/authorization_test.go +++ b/v2/pkg/engine/resolve/authorization_test.go @@ -5,11 +5,12 @@ import ( "context" "encoding/json" "errors" - "github.com/stretchr/testify/require" "io" "sync/atomic" "testing" + "github.com/stretchr/testify/require" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -87,19 +88,19 @@ func TestAuthorization(t *testing.T) { return nil, nil }, func(ctx *Context, dataSourceID string, object json.RawMessage, coordinate GraphCoordinate) (result *AuthorizationDeny, err error) { if dataSourceID == "reviews" && coordinate.TypeName == "User" && coordinate.FieldName == "reviews" { - assert.Equal(t, `{"id":"1234","username":"Me","__typename":"User"}`, string(object)) + assert.Equal(t, `{"id":"1234","username":"Me","__typename":"User","reviews":[{"body":"A highly effective form of birth control.","product":{"upc":"top-1","__typename":"Product","data":{"name":"Trilby"}}},{"body":"Fedoras are one of the most fashionable hats around and can look great with a variety of outfits.","product":{"upc":"top-2","__typename":"Product","data":{"name":"Fedora"}}}]}`, string(object)) assertions.Add(1) } if dataSourceID == "reviews" && coordinate.TypeName == "Review" && coordinate.FieldName == "body" { - assert.Equal(t, `{"body":"A highly effective form of birth control."}`, string(object)) + assert.Equal(t, `{"body":"A highly effective form of birth control.","product":{"upc":"top-1","__typename":"Product","data":{"name":"Trilby"}}}`, string(object)) assertions.Add(1) } if dataSourceID == "reviews" && coordinate.TypeName == "Review" && coordinate.FieldName == "product" { - assert.Equal(t, `{"body":"A highly effective form of birth control."}`, string(object)) + assert.Equal(t, `{"body":"A highly effective form of birth control.","product":{"upc":"top-1","__typename":"Product","data":{"name":"Trilby"}}}`, string(object)) assertions.Add(1) } if dataSourceID == "products" && coordinate.TypeName == "Product" && coordinate.FieldName == "name" { - assert.Equal(t, `{"upc":"top-1","__typename":"Product"}`, string(object)) + assert.Equal(t, `{"upc":"top-1","__typename":"Product","data":{"name":"Trilby"}}`, string(object)) assertions.Add(1) } return nil, nil @@ -621,7 +622,7 @@ func generateTestFederationGraphQLResponse(t *testing.T, ctrl *gomock.Controller DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ SelectResponseErrorsPath: []string{"errors"}, - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -907,7 +908,7 @@ func generateTestFederationGraphQLResponseWithoutAuthorizationRules(t *testing.T FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 6d5a31038..2f2c54b69 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -4,25 +4,29 @@ import ( "bytes" "context" "crypto/tls" - "encoding/json" goerrors "errors" "fmt" "io" "net/http/httptrace" "slices" + "strconv" "strings" "sync" "time" + "github.com/goccy/go-json" + "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" + "github.com/valyala/fastjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" "go.uber.org/atomic" "golang.org/x/sync/errgroup" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" ) @@ -63,9 +67,7 @@ func releaseLoaderBuf(buf *bytes.Buffer) { } type Loader struct { - data *astjson.JSON - dataRoot int - errorsRoot int + resolvable *Resolvable ctx *Context path []string info *GraphQLResponseInfo @@ -81,16 +83,12 @@ type Loader struct { func (l *Loader) Free() { l.info = nil l.ctx = nil - l.data = nil - l.dataRoot = -1 - l.errorsRoot = -1 + l.resolvable = nil l.path = l.path[:0] } func (l *Loader) LoadGraphQLResponseData(ctx *Context, response *GraphQLResponse, resolvable *Resolvable) (err error) { - l.data = resolvable.storage - l.dataRoot = resolvable.dataRoot - l.errorsRoot = resolvable.errorsRoot + l.resolvable = resolvable l.ctx = ctx l.info = response.Info @@ -100,10 +98,10 @@ func (l *Loader) LoadGraphQLResponseData(ctx *Context, response *GraphQLResponse fetchTree = response.Data } - return l.walkNode(fetchTree, []int{resolvable.dataRoot}) + return l.walkNode(fetchTree, []*fastjson.Value{resolvable.data}) } -func (l *Loader) walkNode(node Node, items []int) error { +func (l *Loader) walkNode(node Node, items []*fastjson.Value) error { switch n := node.(type) { case *Object: return l.walkObject(n, items) @@ -153,7 +151,7 @@ func (l *Loader) renderPath() string { return builder.String() } -func (l *Loader) walkObject(object *Object, parentItems []int) (err error) { +func (l *Loader) walkObject(object *Object, parentItems []*fastjson.Value) (err error) { l.pushPath(object.Path) defer l.popPath(object.Path) objectItems := l.selectNodeItems(parentItems, object.Path) @@ -172,7 +170,7 @@ func (l *Loader) walkObject(object *Object, parentItems []int) (err error) { return nil } -func (l *Loader) walkArray(array *Array, parentItems []int) error { +func (l *Loader) walkArray(array *Array, parentItems []*fastjson.Value) error { l.pushPath(array.Path) l.pushArrayPath() nodeItems := l.selectNodeItems(parentItems, array.Path) @@ -182,7 +180,7 @@ func (l *Loader) walkArray(array *Array, parentItems []int) error { return err } -func (l *Loader) selectNodeItems(parentItems []int, path []string) (items []int) { +func (l *Loader) selectNodeItems(parentItems []*fastjson.Value, path []string) (items []*fastjson.Value) { if parentItems == nil { return nil } @@ -190,44 +188,52 @@ func (l *Loader) selectNodeItems(parentItems []int, path []string) (items []int) return parentItems } if len(parentItems) == 1 { - field := l.data.Get(parentItems[0], path) - if field == -1 { + field := parentItems[0].Get(path...) + if field == nil { return nil } - if l.data.Nodes[field].Kind == astjson.NodeKindArray { - return l.data.Nodes[field].ArrayValues + if field.Type() == fastjson.TypeArray { + return field.GetArray() } - return []int{field} + return []*fastjson.Value{field} } - items = make([]int, 0, len(parentItems)) + items = make([]*fastjson.Value, 0, len(parentItems)) for _, parent := range parentItems { - field := l.data.Get(parent, path) - if field == -1 { + field := parent.Get(path...) + if field == nil { continue } - if l.data.Nodes[field].Kind == astjson.NodeKindArray { - items = append(items, l.data.Nodes[field].ArrayValues...) - } else { - items = append(items, field) + if field.Type() == fastjson.TypeArray { + items = append(items, field.GetArray()...) + continue } + items = append(items, field) } return } -func (l *Loader) itemsData(items []int, out io.Writer) error { +func (l *Loader) itemsData(items []*fastjson.Value, out io.Writer) { if len(items) == 0 { - return nil + return } if len(items) == 1 { - return l.data.PrintNode(l.data.Nodes[items[0]], out) + data := items[0].MarshalTo(nil) + _, _ = out.Write(data) + return + } + _, _ = out.Write(lBrack) + var data []byte + for i, item := range items { + if i != 0 { + _, _ = out.Write(comma) + } + data = item.MarshalTo(data[:0]) + _, _ = out.Write(data) } - return l.data.PrintNode(astjson.Node{ - Kind: astjson.NodeKindArray, - ArrayValues: items, - }, out) + _, _ = out.Write(rBrack) } -func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error { +func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []*fastjson.Value) error { switch f := fetch.(type) { case *SingleFetch: res := &result{ @@ -357,7 +363,7 @@ func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error { return nil } -func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, items []int, res *result) error { +func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, items []*fastjson.Value, res *result) error { switch f := fetch.(type) { case *SingleFetch: res.out = acquireLoaderBuf() @@ -405,7 +411,7 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, items []int, res *r return nil } -func (l *Loader) mergeResult(res *result, items []int) error { +func (l *Loader) mergeResult(res *result, items []*fastjson.Value) error { defer releaseLoaderBuf(res.out) if res.err != nil { return l.renderErrorsFailedToFetch(res, failedToFetchNoReason) @@ -415,12 +421,12 @@ func (l *Loader) mergeResult(res *result, items []int) error { if err != nil { return err } + trueValue := fastjson.MustParse(`true`) + skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) + copy(skipErrorsPath, res.postProcessing.MergePath) + skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" for _, item := range items { - l.data.Nodes = append(l.data.Nodes, astjson.Node{ - Kind: astjson.NodeKindNullSkipError, - }) - ref := len(l.data.Nodes) - 1 - l.data.MergeNodesWithPath(item, ref, res.postProcessing.MergePath) + fastjsonext.SetValue(item, trueValue, skipErrorsPath...) } return nil } @@ -429,12 +435,12 @@ func (l *Loader) mergeResult(res *result, items []int) error { if err != nil { return err } + trueValue := fastjson.MustParse(`true`) + skipErrorsPath := make([]string, len(res.postProcessing.MergePath)+1) + copy(skipErrorsPath, res.postProcessing.MergePath) + skipErrorsPath[len(skipErrorsPath)-1] = "__skipErrors" for _, item := range items { - l.data.Nodes = append(l.data.Nodes, astjson.Node{ - Kind: astjson.NodeKindNullSkipError, - }) - ref := len(l.data.Nodes) - 1 - l.data.MergeNodesWithPath(item, ref, res.postProcessing.MergePath) + fastjsonext.SetValue(item, trueValue, skipErrorsPath...) } return nil } @@ -444,8 +450,8 @@ func (l *Loader) mergeResult(res *result, items []int) error { if res.out.Len() == 0 { return l.renderErrorsFailedToFetch(res, emptyGraphQLResponse) } - - node, err := l.data.AppendAnyJSONBytes(res.out.Bytes()) + l.resolvable.maxSize += res.out.Len() + value, err := l.resolvable.parseJSON(res.out.Bytes()) if err != nil { return l.renderErrorsFailedToFetch(res, invalidGraphQLResponse) } @@ -454,11 +460,12 @@ func (l *Loader) mergeResult(res *result, items []int) error { // We check if the subgraph response has errors if res.postProcessing.SelectResponseErrorsPath != nil { - ref := l.data.Get(node, res.postProcessing.SelectResponseErrorsPath) - if ref != -1 { - hasErrors = l.data.NodeIsDefined(ref) && len(l.data.Nodes[ref].ArrayValues) > 0 + errorsValue := value.Get(res.postProcessing.SelectResponseErrorsPath...) + if fastjsonext.ValueIsNonNull(errorsValue) { + errorObjects := errorsValue.GetArray() + hasErrors = len(errorObjects) > 0 // Look for errors in the response and merge them into the errors array - err = l.mergeErrors(res, ref) + err = l.mergeErrors(res, errorsValue, errorObjects) if err != nil { return errors.WithStack(err) } @@ -467,9 +474,9 @@ func (l *Loader) mergeResult(res *result, items []int) error { // We also check if any data is there to processed if res.postProcessing.SelectResponseDataPath != nil { - node = l.data.Get(node, res.postProcessing.SelectResponseDataPath) + value = value.Get(res.postProcessing.SelectResponseDataPath...) // Check if the not set or null - if !l.data.NodeIsDefined(node) { + if fastjsonext.ValueIsNull(value) { // If we didn't get any data nor errors, we return an error because the response is invalid // Returning an error here also avoids the need to walk over it later. if !hasErrors { @@ -478,49 +485,46 @@ func (l *Loader) mergeResult(res *result, items []int) error { // no data return nil } - - // If the data is set, it must be an object according to GraphQL over HTTP spec - if l.data.Nodes[l.data.RootNode].Kind != astjson.NodeKindObject { - return l.renderErrorsFailedToFetch(res, invalidGraphQLResponseShape) - } } withPostProcessing := res.postProcessing.ResponseTemplate != nil if withPostProcessing && len(items) <= 1 { - postProcessed := acquireLoaderBuf() - defer releaseLoaderBuf(postProcessed) - res.out.Reset() - err = l.data.PrintNode(l.data.Nodes[node], res.out) - if err != nil { - return errors.WithStack(err) - } - err = res.postProcessing.ResponseTemplate.Render(l.ctx, res.out.Bytes(), postProcessed) + postProcessed := &bytes.Buffer{} + valueJSON := value.MarshalTo(nil) + err = res.postProcessing.ResponseTemplate.Render(l.ctx, valueJSON, postProcessed) if err != nil { return errors.WithStack(err) } - node, err = l.data.AppendObject(postProcessed.Bytes()) + value, err = l.resolvable.parseJSON(postProcessed.Bytes()) if err != nil { return errors.WithStack(err) } } if len(items) == 0 { - l.data.RootNode = node + // If the data is set, it must be an object according to GraphQL over HTTP spec + if value.Type() != fastjson.TypeObject { + return l.renderErrorsFailedToFetch(res, invalidGraphQLResponseShape) + } + l.resolvable.data = value return nil } if len(items) == 1 && res.batchStats == nil { - l.data.MergeNodesWithPath(items[0], node, res.postProcessing.MergePath) + fastjsonext.MergeValuesWithPath(items[0], value, res.postProcessing.MergePath...) return nil } + batch := value.GetArray() + if batch == nil { + return l.renderErrorsFailedToFetch(res, invalidGraphQLResponseShape) + } if res.batchStats != nil { var ( postProcessed *bytes.Buffer rendered *bytes.Buffer + itemBuffer = make([]byte, 0, 1024) ) if withPostProcessing { - postProcessed = acquireLoaderBuf() - defer releaseLoaderBuf(postProcessed) - rendered = acquireLoaderBuf() - defer releaseLoaderBuf(rendered) + postProcessed = &bytes.Buffer{} + rendered = &bytes.Buffer{} for i, stats := range res.batchStats { postProcessed.Reset() rendered.Reset() @@ -535,10 +539,8 @@ func (l *Loader) mergeResult(res *result, items []int) error { addComma = true continue } - err = l.data.PrintNode(l.data.Nodes[l.data.Nodes[node].ArrayValues[item]], rendered) - if err != nil { - return errors.WithStack(err) - } + itemBuffer = batch[item].MarshalTo(itemBuffer[:0]) + _, _ = rendered.Write(itemBuffer) addComma = true } _, _ = rendered.Write(rBrack) @@ -546,11 +548,8 @@ func (l *Loader) mergeResult(res *result, items []int) error { if err != nil { return errors.WithStack(err) } - nodeProcessed, err := l.data.AppendObject(postProcessed.Bytes()) - if err != nil { - return errors.WithStack(err) - } - l.data.MergeNodesWithPath(items[i], nodeProcessed, res.postProcessing.MergePath) + nodeProcessed := fastjson.MustParseBytes(postProcessed.Bytes()) + fastjsonext.MergeValuesWithPath(items[i], nodeProcessed, res.postProcessing.MergePath...) } } else { for i, stats := range res.batchStats { @@ -558,13 +557,13 @@ func (l *Loader) mergeResult(res *result, items []int) error { if item == -1 { continue } - l.data.MergeNodesWithPath(items[i], l.data.Nodes[node].ArrayValues[item], res.postProcessing.MergePath) + fastjsonext.MergeValuesWithPath(items[i], batch[item], res.postProcessing.MergePath...) } } } } else { for i, item := range items { - l.data.MergeNodesWithPath(item, l.data.Nodes[node].ArrayValues[i], res.postProcessing.MergePath) + fastjsonext.MergeValuesWithPath(item, batch[i], res.postProcessing.MergePath...) } } return nil @@ -618,30 +617,16 @@ func (l *Loader) renderErrorsInvalidInput(out *bytes.Buffer) error { return nil } -func (l *Loader) mergeErrors(res *result, ref int) error { - if l.errorsRoot == -1 { - l.data.Nodes = append(l.data.Nodes, astjson.Node{ - Kind: astjson.NodeKindArray, - }) - l.errorsRoot = len(l.data.Nodes) - 1 - } - +func (l *Loader) mergeErrors(res *result, value *fastjson.Value, values []*fastjson.Value) error { path := l.renderPath() - responseErrorsBuf := acquireLoaderBuf() - defer releaseLoaderBuf(responseErrorsBuf) - - // print them into the buffer to be able to parse them - err := l.data.PrintNode(l.data.Nodes[ref], responseErrorsBuf) - if err != nil { - return err - } - // Serialize subgraph errors from the response - // and append them to the subgraph downsteam errors - if len(l.data.Nodes[ref].ArrayValues) > 0 { - graphqlErrors := make([]GraphQLError, 0, len(l.data.Nodes[ref].ArrayValues)) - err = json.Unmarshal(responseErrorsBuf.Bytes(), &graphqlErrors) + // and append them to the subgraph downstream errors + if len(values) > 0 { + // print them into the buffer to be able to parse them + errorsJSON := value.MarshalTo(nil) + graphqlErrors := make([]GraphQLError, 0, len(values)) + err := json.Unmarshal(errorsJSON, &graphqlErrors) if err != nil { return errors.WithStack(err) } @@ -656,141 +641,97 @@ func (l *Loader) mergeErrors(res *result, ref int) error { l.ctx.appendSubgraphError(goerrors.Join(res.err, subgraphError)) } - l.optionallyOmitErrorExtensions(ref) - l.optionallyOmitErrorLocations(ref) - l.optionallyRewriteErrorPaths(ref) + l.optionallyOmitErrorExtensions(values) + l.optionallyOmitErrorLocations(values) + l.optionallyRewriteErrorPaths(values) if l.subgraphErrorPropagationMode == SubgraphErrorPropagationModePassThrough { - l.data.MergeArrays(l.errorsRoot, ref) - return nil - } - - errorObject, err := l.data.AppendObject([]byte(l.renderSubgraphBaseError(res.subgraphName, path, failedToFetchNoReason))) - if err != nil { - return errors.WithStack(err) - } - - if !l.propagateSubgraphErrors { - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + fastjsonext.MergeValues(l.resolvable.errors, value) return nil } - extensions := l.data.Get(errorObject, []string{"extensions"}) - if extensions == -1 { - extensions, _ = l.data.AppendObject([]byte(`{}`)) - _ = l.data.SetObjectField(errorObject, extensions, "extensions") + errorObject := fastjson.MustParse(l.renderSubgraphBaseError(res.subgraphName, path, failedToFetchNoReason)) + if l.propagateSubgraphErrors { + fastjsonext.SetValue(errorObject, value, "extensions", "errors") } - _ = l.data.SetObjectField(extensions, ref, "errors") l.setSubgraphStatusCode(errorObject, res.statusCode) - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) - + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) return nil } -func (l *Loader) optionallyOmitErrorExtensions(ref int) { +func (l *Loader) optionallyOmitErrorExtensions(values []*fastjson.Value) { if !l.omitSubgraphErrorExtensions { return } -WithNextError: - for _, i := range l.data.Nodes[ref].ArrayValues { - if l.data.Nodes[i].Kind != astjson.NodeKindObject { - continue - } - fields := l.data.Nodes[i].ObjectFields - for j, k := range fields { - key := l.data.ObjectFieldKey(k) - if !bytes.Equal(key, literalExtensions) { - continue - } - l.data.Nodes[i].ObjectFields = append(fields[:j], fields[j+1:]...) - continue WithNextError + for _, value := range values { + if value.Exists("extensions") { + value.Del("extensions") } } } -func (l *Loader) optionallyOmitErrorLocations(ref int) { +func (l *Loader) optionallyOmitErrorLocations(values []*fastjson.Value) { if !l.omitSubgraphErrorLocations { return } -WithNextError: - for _, i := range l.data.Nodes[ref].ArrayValues { - if l.data.Nodes[i].Kind != astjson.NodeKindObject { - continue - } - fields := l.data.Nodes[i].ObjectFields - for j, k := range fields { - key := l.data.ObjectFieldKey(k) - if !bytes.Equal(key, literalLocations) { - continue - } - l.data.Nodes[i].ObjectFields = append(fields[:j], fields[j+1:]...) - continue WithNextError + for _, value := range values { + if value.Exists("locations") { + value.Del("locations") } } } -func (l *Loader) optionallyRewriteErrorPaths(ref int) { +func (l *Loader) optionallyRewriteErrorPaths(values []*fastjson.Value) { if !l.rewriteSubgraphErrorPaths { return } - pathPrefix := make([]int, len(l.path)) - for i := range l.path { - str := l.data.AppendString(l.path[i]) - pathPrefix[i] = str - } + pathPrefix := make([]string, len(l.path)) + copy(pathPrefix, l.path) // remove the trailing @ in case we're in an array as it looks weird in the path // errors, like fetches, are attached to objects, not arrays if len(l.path) != 0 && l.path[len(l.path)-1] == "@" { pathPrefix = pathPrefix[:len(pathPrefix)-1] } -WithNextError: - for _, i := range l.data.Nodes[ref].ArrayValues { - if l.data.Nodes[i].Kind != astjson.NodeKindObject { + for _, value := range values { + errorPath := value.Get("path") + if fastjsonext.ValueIsNull(errorPath) { continue } - fields := l.data.Nodes[i].ObjectFields - for _, j := range fields { - key := l.data.ObjectFieldKey(j) - if !bytes.Equal(key, literalPath) { - continue - } - value := l.data.ObjectFieldValue(j) - if l.data.Nodes[value].Kind != astjson.NodeKindArray { - continue - } - if len(l.data.Nodes[value].ArrayValues) == 0 { - continue WithNextError - } - v := l.data.Nodes[value].ArrayValues[0] - if l.data.Nodes[v].Kind != astjson.NodeKindString { - continue WithNextError - } - elem := l.data.Nodes[v].ValueBytes(l.data) - if !bytes.Equal(elem, literalUnderscoreEntities) { - continue WithNextError + if errorPath.Type() != fastjson.TypeArray { + continue + } + pathItems := errorPath.GetArray() + if len(pathItems) == 0 { + continue + } + for i, item := range pathItems { + if unsafebytes.BytesToString(item.GetStringBytes()) == "_entities" { + // rewrite the path to pathPrefix + pathItems after _entities + newPath := make([]string, 0, len(pathPrefix)+len(pathItems)-i) + newPath = append(newPath, pathPrefix...) + for j := i + 1; j < len(pathItems); j++ { + newPath = append(newPath, unsafebytes.BytesToString(pathItems[j].GetStringBytes())) + } + newPathJSON, _ := json.Marshal(newPath) + value.Set("path", fastjson.MustParseBytes(newPathJSON)) + break } - l.data.Nodes[value].ArrayValues = append(pathPrefix, l.data.Nodes[value].ArrayValues[1:]...) } } } -func (l *Loader) setSubgraphStatusCode(errorObjectRef, statusCode int) { +func (l *Loader) setSubgraphStatusCode(errorObject *fastjson.Value, statusCode int) { if !l.propagateSubgraphStatusCodes { return } if statusCode == 0 { return } - ref := l.data.AppendInt(statusCode) - if ref == -1 { + v, err := fastjson.Parse(strconv.FormatInt(int64(statusCode), 10)) + if err != nil { return } - extensions := l.data.Get(errorObjectRef, []string{"extensions"}) - if extensions == -1 { - extensions, _ = l.data.AppendObject([]byte(`{}`)) - _ = l.data.SetObjectField(errorObjectRef, extensions, "extensions") - } - _ = l.data.SetObjectField(extensions, ref, "statusCode") + fastjsonext.SetValue(errorObject, v, "extensions", "statusCode") } const ( @@ -803,12 +744,12 @@ const ( func (l *Loader) renderErrorsFailedToFetch(res *result, reason string) error { path := l.renderPath() l.ctx.appendSubgraphError(goerrors.Join(res.err, NewSubgraphError(res.subgraphName, path, reason, res.statusCode))) - errorObject, err := l.data.AppendObject([]byte(l.renderSubgraphBaseError(res.subgraphName, path, reason))) + errorObject, err := fastjson.Parse(l.renderSubgraphBaseError(res.subgraphName, path, reason)) if err != nil { - return errors.WithStack(err) + return err } l.setSubgraphStatusCode(errorObject, res.statusCode) - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) return nil } @@ -833,33 +774,21 @@ func (l *Loader) renderAuthorizationRejectedErrors(res *result) error { if res.subgraphName == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Unauthorized Subgraph request at Path '%s'."}`, path))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Unauthorized Subgraph request at Path '%s'."}`, path)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Unauthorized Subgraph request at Path '%s', Reason: %s."}`, path, reason))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Unauthorized Subgraph request at Path '%s', Reason: %s."}`, path, reason)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } } } else { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s' at Path '%s'."}`, res.subgraphName, path))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s' at Path '%s'."}`, res.subgraphName, path)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s' at Path '%s', Reason: %s."}`, res.subgraphName, path, reason))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s' at Path '%s', Reason: %s."}`, res.subgraphName, path, reason)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } } } @@ -872,31 +801,19 @@ func (l *Loader) renderRateLimitRejectedErrors(res *result) error { if res.subgraphName == "" { if res.rateLimitRejectedReason == "" { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request at Path '%s'."}`, path))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request at Path '%s'."}`, path)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request at Path '%s', Reason: %s."}`, path, res.rateLimitRejectedReason))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request at Path '%s', Reason: %s."}`, path, res.rateLimitRejectedReason)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } } else { if res.rateLimitRejectedReason == "" { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s' at Path '%s'."}`, res.subgraphName, path))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s' at Path '%s'."}`, res.subgraphName, path)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := l.data.AppendObject([]byte(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s' at Path '%s', Reason: %s."}`, res.subgraphName, path, res.rateLimitRejectedReason))) - if err != nil { - return errors.WithStack(err) - } - l.data.Nodes[l.errorsRoot].ArrayValues = append(l.data.Nodes[l.errorsRoot].ArrayValues, errorObject) + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s' at Path '%s', Reason: %s."}`, res.subgraphName, path, res.rateLimitRejectedReason)) + fastjsonext.AppendToArray(l.resolvable.errors, errorObject) } } return nil @@ -967,9 +884,14 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a } var ( - singleFetchPool = sync.Pool{} - singleFetchInputSize = atomic.NewInt32(32) - singleFetchPreparedInputSize = atomic.NewInt32(32) + singleFetchPool = sync.Pool{ + New: func() any { + return &singleFetchBuffer{ + input: &bytes.Buffer{}, + preparedInput: &bytes.Buffer{}, + } + }, + } ) type singleFetchBuffer struct { @@ -978,32 +900,20 @@ type singleFetchBuffer struct { } func acquireSingleFetchBuffer() *singleFetchBuffer { - buf := singleFetchPool.Get() - if buf == nil { - return &singleFetchBuffer{ - input: bytes.NewBuffer(make([]byte, 0, int(singleFetchInputSize.Load()))), - preparedInput: bytes.NewBuffer(make([]byte, 0, int(singleFetchPreparedInputSize.Load()))), - } - } - return buf.(*singleFetchBuffer) + return singleFetchPool.Get().(*singleFetchBuffer) } func releaseSingleFetchBuffer(buf *singleFetchBuffer) { - singleFetchInputSize.Store(int32(buf.input.Cap())) - singleFetchPreparedInputSize.Store(int32(buf.preparedInput.Cap())) buf.input.Reset() buf.preparedInput.Reset() singleFetchPool.Put(buf) } -func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, items []int, res *result) error { +func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, items []*fastjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) buf := acquireSingleFetchBuffer() defer releaseSingleFetchBuffer(buf) - err := l.itemsData(items, buf.input) - if err != nil { - return errors.WithStack(err) - } + l.itemsData(items, buf.input) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData { @@ -1012,7 +922,7 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, items fetch.Trace.RawInputData = inputCopy } } - err = fetch.InputTemplate.Render(l.ctx, buf.input.Bytes(), buf.preparedInput) + err := fetch.InputTemplate.Render(l.ctx, buf.input.Bytes(), buf.preparedInput) if err != nil { return l.renderErrorsInvalidInput(res.out) } @@ -1029,48 +939,39 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, items } var ( - entityFetchPool = sync.Pool{} - entityFetchItemDataSize = atomic.NewInt32(32) - entityFetchPreparedInputSize = atomic.NewInt32(32) - entityFetchItemSize = atomic.NewInt32(32) + entityFetchPool = sync.Pool{ + New: func() any { + return &entityFetchBuffer{ + item: &bytes.Buffer{}, + itemData: &bytes.Buffer{}, + preparedInput: &bytes.Buffer{}, + } + }, + } ) type entityFetchBuffer struct { + item *bytes.Buffer itemData *bytes.Buffer preparedInput *bytes.Buffer - item *bytes.Buffer } func acquireEntityFetchBuffer() *entityFetchBuffer { - buf := entityFetchPool.Get() - if buf == nil { - return &entityFetchBuffer{ - itemData: bytes.NewBuffer(make([]byte, 0, int(entityFetchItemDataSize.Load()))), - preparedInput: bytes.NewBuffer(make([]byte, 0, int(entityFetchPreparedInputSize.Load()))), - item: bytes.NewBuffer(make([]byte, 0, int(entityFetchItemSize.Load()))), - } - } - return buf.(*entityFetchBuffer) + return entityFetchPool.Get().(*entityFetchBuffer) } func releaseEntityFetchBuffer(buf *entityFetchBuffer) { - entityFetchItemDataSize.Store(int32(buf.itemData.Cap())) - entityFetchPreparedInputSize.Store(int32(buf.preparedInput.Cap())) - entityFetchItemSize.Store(int32(buf.item.Cap())) + buf.item.Reset() buf.itemData.Reset() buf.preparedInput.Reset() - buf.item.Reset() entityFetchPool.Put(buf) } -func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items []int, res *result) error { +func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items []*fastjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) buf := acquireEntityFetchBuffer() defer releaseEntityFetchBuffer(buf) - err := l.itemsData(items, buf.itemData) - if err != nil { - return errors.WithStack(err) - } + l.itemsData(items, buf.itemData) if l.ctx.TracingOptions.Enable { fetch.Trace = &DataSourceLoadTrace{} @@ -1083,7 +984,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items var undefinedVariables []string - err = fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) + err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables) if err != nil { return errors.WithStack(err) } @@ -1150,13 +1051,11 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items var ( batchEntityFetchPool = sync.Pool{} batchEntityPreparedInputSize = atomic.NewInt32(32) - batchEntityItemDataSize = atomic.NewInt32(32) batchEntityItemInputSize = atomic.NewInt32(32) ) type batchEntityFetchBuffer struct { preparedInput *bytes.Buffer - itemData *bytes.Buffer itemInput *bytes.Buffer keyGen *xxhash.Digest } @@ -1166,7 +1065,6 @@ func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer { if buf == nil { return &batchEntityFetchBuffer{ preparedInput: bytes.NewBuffer(make([]byte, 0, int(batchEntityPreparedInputSize.Load()))), - itemData: bytes.NewBuffer(make([]byte, 0, int(batchEntityItemDataSize.Load()))), itemInput: bytes.NewBuffer(make([]byte, 0, int(batchEntityItemInputSize.Load()))), keyGen: xxhash.New(), } @@ -1176,16 +1074,14 @@ func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer { func releaseBatchEntityFetchBuffer(buf *batchEntityFetchBuffer) { batchEntityPreparedInputSize.Store(int32(buf.preparedInput.Cap())) - batchEntityItemDataSize.Store(int32(buf.itemData.Cap())) batchEntityItemInputSize.Store(int32(buf.itemInput.Cap())) buf.preparedInput.Reset() - buf.itemData.Reset() buf.itemInput.Reset() buf.keyGen.Reset() batchEntityFetchPool.Put(buf) } -func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFetch, items []int, res *result) error { +func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFetch, items []*fastjson.Value, res *result) error { res.init(fetch.PostProcessing, fetch.Info) buf := acquireBatchEntityFetchBuffer() @@ -1195,10 +1091,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFet fetch.Trace = &DataSourceLoadTrace{} if !l.ctx.TracingOptions.ExcludeRawInputData { buf := &bytes.Buffer{} - err := l.itemsData(items, buf) - if err != nil { - return errors.WithStack(err) - } + l.itemsData(items, buf) fetch.Trace.RawInputData = buf.Bytes() } } @@ -1213,17 +1106,14 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFet itemHashes := make([]uint64, 0, len(items)*len(fetch.Input.Items)) batchItemIndex := 0 addSeparator := false + itemData := make([]byte, 0, 1024) WithNextItem: for i, item := range items { - buf.itemData.Reset() - err = l.data.PrintNode(l.data.Nodes[item], buf.itemData) - if err != nil { - return errors.WithStack(err) - } + itemData = item.MarshalTo(itemData[:0]) for j := range fetch.Input.Items { buf.itemInput.Reset() - err = fetch.Input.Items[j].Render(l.ctx, buf.itemData.Bytes(), buf.itemInput) + err = fetch.Input.Items[j].Render(l.ctx, itemData, buf.itemInput) if err != nil { if fetch.Input.SkipErrItems { err = nil // nolint:ineffassign diff --git a/v2/pkg/engine/resolve/loader_test.go b/v2/pkg/engine/resolve/loader_test.go index d12b42934..9d61f5384 100644 --- a/v2/pkg/engine/resolve/loader_test.go +++ b/v2/pkg/engine/resolve/loader_test.go @@ -1,7 +1,6 @@ package resolve import ( - "bytes" "context" "encoding/json" "net/http" @@ -10,7 +9,7 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" ) func TestLoader_LoadGraphQLResponseData(t *testing.T) { @@ -285,20 +284,17 @@ func TestLoader_LoadGraphQLResponseData(t *testing.T) { ctx := &Context{ ctx: context.Background(), } - resolvable := &Resolvable{ - storage: &astjson.JSON{}, - } + resolvable := &Resolvable{} loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) err = loader.LoadGraphQLResponseData(ctx, response, resolvable) assert.NoError(t, err) ctrl.Finish() - out := &bytes.Buffer{} - err = resolvable.storage.PrintNode(resolvable.storage.Nodes[resolvable.storage.RootNode], out) + out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` - assert.Equal(t, expected, out.String()) + assert.Equal(t, expected, out) } func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { @@ -574,20 +570,17 @@ func TestLoader_LoadGraphQLResponseDataWithExtensions(t *testing.T) { ctx: context.Background(), Extensions: []byte(`{"foo":"bar"}`), } - resolvable := &Resolvable{ - storage: &astjson.JSON{}, - } + resolvable := &Resolvable{} loader := &Loader{} err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) assert.NoError(t, err) err = loader.LoadGraphQLResponseData(ctx, response, resolvable) assert.NoError(t, err) ctrl.Finish() - out := &bytes.Buffer{} - err = resolvable.storage.PrintNode(resolvable.storage.Nodes[resolvable.storage.RootNode], out) + out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) assert.NoError(t, err) expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` - assert.Equal(t, expected, out.String()) + assert.Equal(t, expected, out) } func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { @@ -852,17 +845,13 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { ctx := &Context{ ctx: context.Background(), } - resolvable := &Resolvable{ - storage: &astjson.JSON{}, - } + resolvable := &Resolvable{} loader := &Loader{} - expected := []byte(`{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}`) - out := &bytes.Buffer{} + expected := `{"errors":[],"data":{"topProducts":[{"name":"Table","__typename":"Product","upc":"1","reviews":[{"body":"Love Table!","author":{"__typename":"User","id":"1","name":"user-1"}},{"body":"Prefer other Table.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":8},{"name":"Couch","__typename":"Product","upc":"2","reviews":[{"body":"Couch Too expensive.","author":{"__typename":"User","id":"1","name":"user-1"}}],"stock":2},{"name":"Chair","__typename":"Product","upc":"3","reviews":[{"body":"Chair Could be better.","author":{"__typename":"User","id":"2","name":"user-2"}}],"stock":5}]}}` b.SetBytes(int64(len(expected))) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - out.Reset() loader.Free() resolvable.Reset() err := resolvable.Init(ctx, nil, ast.OperationTypeQuery) @@ -873,12 +862,9 @@ func BenchmarkLoader_LoadGraphQLResponseData(b *testing.B) { if err != nil { b.Fatal(err) } - err = resolvable.storage.PrintNode(resolvable.storage.Nodes[resolvable.storage.RootNode], out) - if err != nil { - b.Fatal(err) - } - if !bytes.Equal(expected, out.Bytes()) { - b.Fatal("not equal") + out := fastjsonext.PrintGraphQLResponse(resolvable.data, resolvable.errors) + if expected != out { + b.Fatalf("expected %s, got %s", expected, out) } } } diff --git a/v2/pkg/engine/resolve/resolvable.go b/v2/pkg/engine/resolve/resolvable.go index b5fa72aaa..dee7655f3 100644 --- a/v2/pkg/engine/resolve/resolvable.go +++ b/v2/pkg/engine/resolve/resolvable.go @@ -3,30 +3,35 @@ package resolve import ( "bytes" "context" - "encoding/json" goerrors "errors" "fmt" "io" + "github.com/goccy/go-json" + "github.com/cespare/xxhash/v2" "github.com/pkg/errors" "github.com/tidwall/gjson" + "github.com/valyala/fastjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/fastjsonext" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" "github.com/wundergraph/graphql-go-tools/v2/pkg/pool" ) type Resolvable struct { - storage *astjson.JSON - dataRoot int - errorsRoot int - variablesRoot int + data *fastjson.Value + errors *fastjson.Value + variables *fastjson.Value + skipAddingNullErrors bool + + parsers []*fastjson.Parser + print bool out io.Writer printErr error - path []astjson.PathElement + path []fastjsonext.PathElement depth int operationType ast.OperationType renameTypeNames []RenameTypeName @@ -36,32 +41,52 @@ type Resolvable struct { authorizationAllow map[uint64]struct{} authorizationDeny map[uint64]string - authorizationBuf *bytes.Buffer - authorizationBufObjectRef int - wroteErrors bool wroteData bool typeNames [][]byte + + // maxSize is the sum of all responses to get the possible maximum size of the response + maxSize int + + marshalBuf []byte } func NewResolvable() *Resolvable { return &Resolvable{ - storage: &astjson.JSON{}, xxh: xxhash.New(), authorizationAllow: make(map[uint64]struct{}), authorizationDeny: make(map[uint64]string), } } +var ( + parsers = &fastjson.ParserPool{} +) + +func (r *Resolvable) parseJSON(data []byte) (*fastjson.Value, error) { + parser := parsers.Get() + r.parsers = append(r.parsers, parser) + return parser.ParseBytes(data) +} + +func (r *Resolvable) MaxSize() int { + return r.maxSize +} + func (r *Resolvable) Reset() { - r.storage.Reset() + for i := range r.parsers { + parsers.Put(r.parsers[i]) + r.parsers[i] = nil + } + r.maxSize = 0 + r.parsers = r.parsers[:0] r.typeNames = r.typeNames[:0] r.wroteErrors = false r.wroteData = false - r.dataRoot = -1 - r.errorsRoot = -1 - r.variablesRoot = -1 + r.data = nil + r.errors = nil + r.variables = nil r.depth = 0 r.print = false r.out = nil @@ -71,7 +96,6 @@ func (r *Resolvable) Reset() { r.renameTypeNames = r.renameTypeNames[:0] r.authorizationError = nil r.xxh.Reset() - r.authorizationBufObjectRef = -1 for k := range r.authorizationAllow { delete(r.authorizationAllow, k) } @@ -84,12 +108,14 @@ func (r *Resolvable) Init(ctx *Context, initialData []byte, operationType ast.Op r.ctx = ctx r.operationType = operationType r.renameTypeNames = ctx.RenameTypeNames - r.dataRoot, r.errorsRoot, err = r.storage.InitResolvable(initialData) - if err != nil { - return - } + r.data = fastjson.MustParse(`{}`) + r.errors = fastjson.MustParse(`[]`) if len(ctx.Variables) != 0 { - r.variablesRoot, err = r.storage.AppendAnyJSONBytes(ctx.Variables) + r.variables = fastjson.MustParseBytes(ctx.Variables) + } + if initialData != nil { + initialValue := fastjson.MustParseBytes(initialData) + r.data, _ = fastjsonext.MergeValues(r.data, initialValue) } return } @@ -99,26 +125,33 @@ func (r *Resolvable) InitSubscription(ctx *Context, initialData []byte, postProc r.operationType = ast.OperationTypeSubscription r.renameTypeNames = ctx.RenameTypeNames if len(ctx.Variables) != 0 { - r.variablesRoot, err = r.storage.AppendObject(ctx.Variables) + r.variables = fastjson.MustParseBytes(ctx.Variables) + } + if initialData != nil { + initialValue, err := fastjson.ParseBytes(initialData) if err != nil { - return + return err + } + if postProcessing.SelectResponseDataPath == nil { + r.data, _ = fastjsonext.MergeValuesWithPath(r.data, initialValue, postProcessing.MergePath...) + } else { + selectedInitialValue := initialValue.Get(postProcessing.SelectResponseDataPath...) + if selectedInitialValue != nil { + r.data, _ = fastjsonext.MergeValuesWithPath(r.data, selectedInitialValue, postProcessing.MergePath...) + } + } + if postProcessing.SelectResponseErrorsPath != nil { + selectedInitialErrors := initialValue.Get(postProcessing.SelectResponseErrorsPath...) + if selectedInitialErrors != nil { + r.errors = selectedInitialErrors + } } } - r.dataRoot, r.errorsRoot, err = r.storage.InitResolvable(nil) - if err != nil { - return - } - raw, err := r.storage.AppendObject(initialData) - if err != nil { - return err - } - data := r.storage.Get(raw, postProcessing.SelectResponseDataPath) - if r.storage.NodeIsDefined(data) { - r.storage.MergeNodesWithPath(r.dataRoot, data, postProcessing.MergePath) + if r.data == nil { + r.data = fastjson.MustParse(`{}`) } - errs := r.storage.Get(raw, postProcessing.SelectResponseErrorsPath) - if r.storage.NodeIsDefined(errs) { - r.storage.MergeArrays(r.errorsRoot, errs) + if r.errors == nil { + r.errors = fastjson.MustParse(`[]`) } return } @@ -128,12 +161,9 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *O r.print = false r.printErr = nil r.authorizationError = nil + r.skipAddingNullErrors = r.hasErrors() && !r.hasData() - /* @TODO: In the event of an error or failed fetch, propagate only the highest level errors. - * For example, if a fetch fails, only propagate that the fetch has failed; do not propagate nested non-null errors. - */ - - err := r.walkObject(rootData, r.dataRoot) + hasErrors := r.walkObject(rootData, r.data) if r.authorizationError != nil { return r.authorizationError } @@ -142,7 +172,7 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *O r.printErrors() } - if err { + if hasErrors { r.printBytes(quote) r.printBytes(literalData) r.printBytes(quote) @@ -156,7 +186,6 @@ func (r *Resolvable) Resolve(ctx context.Context, rootData *Object, fetchTree *O r.printErr = r.printExtensions(ctx, fetchTree) } r.printBytes(rBrace) - return r.printErr } @@ -169,7 +198,7 @@ func (r *Resolvable) printErrors() { r.printBytes(literalErrors) r.printBytes(quote) r.printBytes(colon) - r.printNode(r.errorsRoot) + r.printNode(r.errors) r.printBytes(comma) r.wroteErrors = true } @@ -181,7 +210,7 @@ func (r *Resolvable) printData(root *Object) { r.printBytes(colon) r.printBytes(lBrace) r.print = true - _ = r.walkObject(root, r.dataRoot) + _ = r.walkObject(root, r.data) r.print = false r.printBytes(rBrace) r.wroteData = true @@ -284,18 +313,25 @@ func (r *Resolvable) WroteErrorsWithoutData() bool { } func (r *Resolvable) hasErrors() bool { - return r.storage.NodeIsDefined(r.errorsRoot) && - len(r.storage.Nodes[r.errorsRoot].ArrayValues) > 0 + if r.errors == nil { + return false + } + values, err := r.errors.Array() + if err != nil { + return false + } + return len(values) > 0 } func (r *Resolvable) hasData() bool { - if !r.storage.NodeIsDefined(r.dataRoot) { + if r.data == nil { return false } - if r.storage.Nodes[r.dataRoot].Kind != astjson.NodeKindObject { + obj, err := r.data.Object() + if err != nil { return false } - return len(r.storage.Nodes[r.dataRoot].ObjectFields) > 0 + return obj.Len() > 0 } func (r *Resolvable) printBytes(b []byte) { @@ -305,16 +341,17 @@ func (r *Resolvable) printBytes(b []byte) { _, r.printErr = r.out.Write(b) } -func (r *Resolvable) printNode(ref int) { +func (r *Resolvable) printNode(value *fastjson.Value) { if r.printErr != nil { return } - r.printErr = r.storage.PrintNode(r.storage.Nodes[ref], r.out) + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + _, r.printErr = r.out.Write(r.marshalBuf) } func (r *Resolvable) pushArrayPathElement(index int) { - r.path = append(r.path, astjson.PathElement{ - ArrayIndex: index, + r.path = append(r.path, fastjsonext.PathElement{ + Idx: index, }) } @@ -325,7 +362,7 @@ func (r *Resolvable) popArrayPathElement() { func (r *Resolvable) pushNodePathElement(path []string) { r.depth++ for i := range path { - r.path = append(r.path, astjson.PathElement{ + r.path = append(r.path, fastjsonext.PathElement{ Name: path[i], }) } @@ -336,7 +373,7 @@ func (r *Resolvable) popNodePathElement(path []string) { r.depth-- } -func (r *Resolvable) walkNode(node Node, ref int) bool { +func (r *Resolvable) walkNode(node Node, value *fastjson.Value) bool { if r.authorizationError != nil { return true } @@ -345,51 +382,47 @@ func (r *Resolvable) walkNode(node Node, ref int) bool { } switch n := node.(type) { case *Object: - return r.walkObject(n, ref) + return r.walkObject(n, value) case *Array: - return r.walkArray(n, ref) + return r.walkArray(n, value) case *Null: return r.walkNull() case *String: - return r.walkString(n, ref) + return r.walkString(n, value) case *Boolean: - return r.walkBoolean(n, ref) + return r.walkBoolean(n, value) case *Integer: - return r.walkInteger(n, ref) + return r.walkInteger(n, value) case *Float: - return r.walkFloat(n, ref) + return r.walkFloat(n, value) case *BigInt: - return r.walkBigInt(n, ref) + return r.walkBigInt(n, value) case *Scalar: - return r.walkScalar(n, ref) + return r.walkScalar(n, value) case *EmptyObject: return r.walkEmptyObject(n) case *EmptyArray: return r.walkEmptyArray(n) case *CustomNode: - return r.walkCustom(n, ref) + return r.walkCustom(n, value) default: return false } } -func (r *Resolvable) walkObject(obj *Object, ref int) bool { - ref = r.storage.Get(ref, obj.Path) - if !r.storage.NodeIsDefined(ref) { +func (r *Resolvable) walkObject(obj *Object, parent *fastjson.Value) bool { + value := parent.Get(obj.Path...) + if value == nil || value.Type() == fastjson.TypeNull { if obj.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, obj.Path) + r.addNonNullableFieldError(obj.Path, parent) return r.err() } r.pushNodePathElement(obj.Path) isRoot := r.depth < 2 defer r.popNodePathElement(obj.Path) - - if r.storage.Nodes[ref].Kind == astjson.NodeKindNull { - return r.walkNull() - } - if r.storage.Nodes[ref].Kind != astjson.NodeKindObject { + if value.Type() != fastjson.TypeObject { r.addError("Object cannot represent non-object value.", obj.Path) return r.err() } @@ -398,7 +431,7 @@ func (r *Resolvable) walkObject(obj *Object, ref int) bool { r.ctx.Stats.ResolvedObjects++ } addComma := false - typeName := r.getObjectTypeName(ref) + typeName := value.GetStringBytes("__typename") r.typeNames = append(r.typeNames, typeName) defer func() { r.typeNames = r.typeNames[:len(r.typeNames)-1] @@ -425,19 +458,20 @@ func (r *Resolvable) walkObject(obj *Object, ref int) bool { } } if !r.print { - skip := r.authorizeField(ref, obj.Fields[i]) + skip := r.authorizeField(value, obj.Fields[i]) if skip { if obj.Fields[i].Value.NodeNullable() { // if the field value is nullable, we can just set it to null // we already set an error in authorizeField - field := r.storage.Get(ref, obj.Fields[i].Value.NodePath()) - if r.storage.NodeIsDefined(field) { - r.storage.Nodes[field].Kind = astjson.NodeKindNull + path := obj.Fields[i].Value.NodePath() + field := value.Get(path...) + if field != nil { + fastjsonext.SetNull(value, path...) } } else if obj.Nullable { // if the field value is not nullable, but the object is nullable // we can just set the whole object to null - r.storage.Nodes[ref].Kind = astjson.NodeKindNull + fastjsonext.SetNull(parent, obj.Path...) } else { // if the field value is not nullable and the object is not nullable // we return true to indicate an error @@ -455,11 +489,13 @@ func (r *Resolvable) walkObject(obj *Object, ref int) bool { r.printBytes(quote) r.printBytes(colon) } - err := r.walkNode(obj.Fields[i].Value, ref) + err := r.walkNode(obj.Fields[i].Value, value) if err { if obj.Nullable { - r.storage.Nodes[ref].Kind = astjson.NodeKindNull - return false + if len(obj.Path) > 0 { + fastjsonext.SetNull(parent, obj.Path...) + return false + } } return err } @@ -471,7 +507,7 @@ func (r *Resolvable) walkObject(obj *Object, ref int) bool { return false } -func (r *Resolvable) authorizeField(ref int, field *Field) (skipField bool) { +func (r *Resolvable) authorizeField(value *fastjson.Value, field *Field) (skipField bool) { if field.Info == nil { return false } @@ -485,13 +521,13 @@ func (r *Resolvable) authorizeField(ref int, field *Field) (skipField bool) { return false } dataSourceID := field.Info.Source.IDs[0] - typeName := r.objectFieldTypeName(ref, field) + typeName := r.objectFieldTypeName(value, field) fieldName := unsafebytes.BytesToString(field.Name) gc := GraphCoordinate{ TypeName: typeName, FieldName: fieldName, } - result, authErr := r.authorize(ref, dataSourceID, gc) + result, authErr := r.authorize(value, dataSourceID, gc) if authErr != nil { r.authorizationError = authErr return true @@ -503,7 +539,7 @@ func (r *Resolvable) authorizeField(ref int, field *Field) (skipField bool) { return false } -func (r *Resolvable) authorize(objectRef int, dataSourceID string, coordinate GraphCoordinate) (result *AuthorizationDeny, err error) { +func (r *Resolvable) authorize(value *fastjson.Value, dataSourceID string, coordinate GraphCoordinate) (result *AuthorizationDeny, err error) { r.xxh.Reset() _, _ = r.xxh.WriteString(dataSourceID) _, _ = r.xxh.WriteString(coordinate.TypeName) @@ -515,18 +551,8 @@ func (r *Resolvable) authorize(objectRef int, dataSourceID string, coordinate Gr if reason, ok := r.authorizationDeny[decisionID]; ok { return &AuthorizationDeny{Reason: reason}, nil } - if r.authorizationBufObjectRef != objectRef { - if r.authorizationBuf == nil { - r.authorizationBuf = bytes.NewBuffer(nil) - } - r.authorizationBuf.Reset() - err = r.storage.PrintObjectFlat(objectRef, r.authorizationBuf) - if err != nil { - return nil, err - } - r.authorizationBufObjectRef = objectRef - } - result, err = r.ctx.authorizer.AuthorizeObjectField(r.ctx, dataSourceID, r.authorizationBuf.Bytes(), coordinate) + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + result, err = r.ctx.authorizer.AuthorizeObjectField(r.ctx, dataSourceID, r.marshalBuf, coordinate) if err != nil { return nil, err } @@ -550,29 +576,18 @@ func (r *Resolvable) addRejectFieldError(reason, dataSourceID string, field *Fie errorMessage = fmt.Sprintf("Unauthorized to load field '%s', Reason: %s.", fieldPath, reason) } r.ctx.appendSubgraphError(goerrors.Join(errors.New(errorMessage), NewSubgraphError(dataSourceID, fieldPath, reason, 0))) - - ref := r.storage.AppendErrorWithMessage(errorMessage, r.path) - r.storage.Nodes[r.errorsRoot].ArrayValues = append(r.storage.Nodes[r.errorsRoot].ArrayValues, ref) + fastjsonext.AppendErrorToArray(r.errors, errorMessage, r.path) r.popNodePathElement(nodePath) } -func (r *Resolvable) objectFieldTypeName(ref int, field *Field) string { - typeName := r.storage.GetObjectField(ref, "__typename") - if r.storage.NodeIsDefined(typeName) && r.storage.Nodes[typeName].Kind == astjson.NodeKindString { - name := r.storage.Nodes[typeName].ValueBytes(r.storage) - return unsafebytes.BytesToString(name) +func (r *Resolvable) objectFieldTypeName(v *fastjson.Value, field *Field) string { + typeName := v.GetStringBytes("__typename") + if typeName != nil { + return unsafebytes.BytesToString(typeName) } return field.Info.ExactParentTypeName } -func (r *Resolvable) getObjectTypeName(ref int) []byte { - typeName := r.storage.GetObjectField(ref, "__typename") - if r.storage.NodeIsDefined(typeName) && r.storage.Nodes[typeName].Kind == astjson.NodeKindString { - return r.storage.Nodes[typeName].ValueBytes(r.storage) - } - return nil -} - func (r *Resolvable) skipFieldOnParentTypeNames(field *Field) bool { WithNext: for i := range field.ParentOnTypeNames { @@ -612,57 +627,55 @@ func (r *Resolvable) skipFieldOnTypeNames(field *Field) bool { } func (r *Resolvable) skipField(skipVariableName string) bool { - field := r.storage.GetObjectField(r.variablesRoot, skipVariableName) - if !r.storage.NodeIsDefined(field) { - return false - } - if r.storage.Nodes[field].Kind != astjson.NodeKindBoolean { + variable := r.variables.Get(skipVariableName) + if variable == nil { return false } - value := r.storage.Nodes[field].ValueBytes(r.storage) - return bytes.Equal(value, literalTrue) + return variable.Type() == fastjson.TypeTrue } func (r *Resolvable) excludeField(includeVariableName string) bool { - field := r.storage.GetObjectField(r.variablesRoot, includeVariableName) - if !r.storage.NodeIsDefined(field) { + variable := r.variables.Get(includeVariableName) + if variable == nil { return true } - if r.storage.Nodes[field].Kind != astjson.NodeKindBoolean { - return true - } - value := r.storage.Nodes[field].ValueBytes(r.storage) - return bytes.Equal(value, literalFalse) + return variable.Type() == fastjson.TypeFalse } -func (r *Resolvable) walkArray(arr *Array, ref int) bool { - ref = r.storage.Get(ref, arr.Path) - if !r.storage.NodeIsDefined(ref) { +func (r *Resolvable) walkArray(arr *Array, value *fastjson.Value) bool { + parent := value + value = value.Get(arr.Path...) + if fastjsonext.ValueIsNull(value) { if arr.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, arr.Path) + r.addNonNullableFieldError(arr.Path, parent) return r.err() } r.pushNodePathElement(arr.Path) defer r.popNodePathElement(arr.Path) - if r.storage.Nodes[ref].Kind != astjson.NodeKindArray { + if value.Type() != fastjson.TypeArray { r.addError("Array cannot represent non-array value.", arr.Path) return r.err() } if r.print { r.printBytes(lBrack) } - for i, value := range r.storage.Nodes[ref].ArrayValues { + values := value.GetArray() + for i, arrayValue := range values { if r.print && i != 0 { r.printBytes(comma) } r.pushArrayPathElement(i) - err := r.walkNode(arr.Item, value) + err := r.walkNode(arr.Item, arrayValue) r.popArrayPathElement() if err { + if arr.Item.NodeKind() == NodeKindObject && arr.Item.NodeNullable() { + value.SetArrayItem(i, fastjsonext.NullValue) + continue + } if arr.Nullable { - r.storage.Nodes[ref].Kind = astjson.NodeKindNull + fastjsonext.SetNull(parent, arr.Path...) return false } return err @@ -682,155 +695,161 @@ func (r *Resolvable) walkNull() bool { return false } -func (r *Resolvable) walkString(s *String, ref int) bool { +func (r *Resolvable) walkString(s *String, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, s.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(s.Path...) + if fastjsonext.ValueIsNull(value) { if s.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, s.Path) + r.addNonNullableFieldError(s.Path, parent) return r.err() } - if r.storage.Nodes[ref].Kind != astjson.NodeKindString { - value := string(r.storage.Nodes[ref].ValueBytes(r.storage)) - r.addError(fmt.Sprintf("String cannot represent non-string value: \\\"%s\\\"", value), s.Path) + if value.Type() != fastjson.TypeString { + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + r.addError(fmt.Sprintf("String cannot represent non-string value: \\\"%s\\\"", string(r.marshalBuf)), s.Path) return r.err() } if r.print { if s.IsTypeName { - value := r.storage.Nodes[ref].ValueBytes(r.storage) + content := value.GetStringBytes() for i := range r.renameTypeNames { - if bytes.Equal(value, r.renameTypeNames[i].From) { + if bytes.Equal(content, r.renameTypeNames[i].From) { r.printBytes(quote) r.printBytes(r.renameTypeNames[i].To) r.printBytes(quote) return false } } - r.printNode(ref) + r.printNode(value) return false } if s.UnescapeResponseJson { - value := r.storage.Nodes[ref].ValueBytes(r.storage) - value = bytes.ReplaceAll(value, []byte(`\"`), []byte(`"`)) - if !gjson.ValidBytes(value) { + content := value.GetStringBytes() + content = bytes.ReplaceAll(content, []byte(`\"`), []byte(`"`)) + if !gjson.ValidBytes(content) { r.printBytes(quote) - r.printBytes(value) + r.printBytes(content) r.printBytes(quote) } else { - r.printBytes(value) + r.printBytes(content) } } else { - r.printNode(ref) + r.printNode(value) } } return false } -func (r *Resolvable) walkBoolean(b *Boolean, ref int) bool { +func (r *Resolvable) walkBoolean(b *Boolean, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, b.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(b.Path...) + if fastjsonext.ValueIsNull(value) { if b.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, b.Path) + r.addNonNullableFieldError(b.Path, parent) return r.err() } - if r.storage.Nodes[ref].Kind != astjson.NodeKindBoolean { - value := string(r.storage.Nodes[ref].ValueBytes(r.storage)) - r.addError(fmt.Sprintf("Bool cannot represent non-boolean value: \\\"%s\\\"", value), b.Path) + if value.Type() != fastjson.TypeTrue && value.Type() != fastjson.TypeFalse { + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + r.addError(fmt.Sprintf("Bool cannot represent non-boolean value: \\\"%s\\\"", string(r.marshalBuf)), b.Path) return r.err() } if r.print { - r.printNode(ref) + r.printNode(value) } return false } -func (r *Resolvable) walkInteger(i *Integer, ref int) bool { +func (r *Resolvable) walkInteger(i *Integer, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, i.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(i.Path...) + if fastjsonext.ValueIsNull(value) { if i.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, i.Path) + r.addNonNullableFieldError(i.Path, parent) return r.err() } - if r.storage.Nodes[ref].Kind != astjson.NodeKindNumber { - value := string(r.storage.Nodes[ref].ValueBytes(r.storage)) - r.addError(fmt.Sprintf("Int cannot represent non-integer value: \\\"%s\\\"", value), i.Path) + if value.Type() != fastjson.TypeNumber { + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + r.addError(fmt.Sprintf("Int cannot represent non-integer value: \\\"%s\\\"", string(r.marshalBuf)), i.Path) return r.err() } if r.print { - r.printNode(ref) + r.printNode(value) } return false } -func (r *Resolvable) walkFloat(f *Float, ref int) bool { +func (r *Resolvable) walkFloat(f *Float, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, f.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(f.Path...) + if fastjsonext.ValueIsNull(value) { if f.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, f.Path) + r.addNonNullableFieldError(f.Path, parent) return r.err() } - if r.storage.Nodes[ref].Kind != astjson.NodeKindNumber { - value := string(r.storage.Nodes[ref].ValueBytes(r.storage)) - r.addError(fmt.Sprintf("Float cannot represent non-float value: \\\"%s\\\"", value), f.Path) + if value.Type() != fastjson.TypeNumber { + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + r.addError(fmt.Sprintf("Float cannot represent non-float value: \\\"%s\\\"", string(r.marshalBuf)), f.Path) return r.err() } if r.print { - r.printNode(ref) + r.printNode(value) } return false } -func (r *Resolvable) walkBigInt(b *BigInt, ref int) bool { +func (r *Resolvable) walkBigInt(b *BigInt, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, b.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(b.Path...) + if fastjsonext.ValueIsNull(value) { if b.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, b.Path) + r.addNonNullableFieldError(b.Path, parent) return r.err() } if r.print { - r.printNode(ref) + r.printNode(value) } return false } -func (r *Resolvable) walkScalar(s *Scalar, ref int) bool { +func (r *Resolvable) walkScalar(s *Scalar, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, s.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(s.Path...) + if fastjsonext.ValueIsNull(value) { if s.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, s.Path) + r.addNonNullableFieldError(s.Path, parent) return r.err() } if r.print { - r.printNode(ref) + r.printNode(value) } return false } @@ -851,20 +870,21 @@ func (r *Resolvable) walkEmptyArray(_ *EmptyArray) bool { return false } -func (r *Resolvable) walkCustom(c *CustomNode, ref int) bool { +func (r *Resolvable) walkCustom(c *CustomNode, value *fastjson.Value) bool { if r.print { r.ctx.Stats.ResolvedLeafs++ } - ref = r.storage.Get(ref, c.Path) - if !r.storage.NodeIsDefined(ref) { + parent := value + value = value.Get(c.Path...) + if fastjsonext.ValueIsNull(value) { if c.Nullable { return r.walkNull() } - r.addNonNullableFieldError(ref, c.Path) + r.addNonNullableFieldError(c.Path, parent) return r.err() } - value := r.storage.Nodes[ref].ValueBytes(r.storage) - resolved, err := c.Resolve(r.ctx, value) + r.marshalBuf = value.MarshalTo(r.marshalBuf[:0]) + resolved, err := c.Resolve(r.ctx, r.marshalBuf) if err != nil { r.addError(err.Error(), c.Path) return r.err() @@ -875,13 +895,20 @@ func (r *Resolvable) walkCustom(c *CustomNode, ref int) bool { return false } -func (r *Resolvable) addNonNullableFieldError(fieldRef int, fieldPath []string) { - if fieldRef != -1 && r.storage.Nodes[fieldRef].Kind == astjson.NodeKindNullSkipError { +func (r *Resolvable) addNonNullableFieldError(fieldPath []string, parent *fastjson.Value) { + if r.skipAddingNullErrors { return } + if fieldPath != nil { + if ancestor := parent.Get(fieldPath[:len(fieldPath)-1]...); ancestor != nil { + if ancestor.Exists("__skipErrors") { + return + } + } + } r.pushNodePathElement(fieldPath) - ref := r.storage.AppendNonNullableFieldIsNullErr(r.renderFieldPath(), r.path) - r.storage.Nodes[r.errorsRoot].ArrayValues = append(r.storage.Nodes[r.errorsRoot].ArrayValues, ref) + errorMessage := fmt.Sprintf("Cannot return null for non-nullable field '%s'.", r.renderFieldPath()) + fastjsonext.AppendErrorToArray(r.errors, errorMessage, r.path) r.popNodePathElement(fieldPath) } @@ -907,7 +934,6 @@ func (r *Resolvable) renderFieldPath() string { func (r *Resolvable) addError(message string, fieldPath []string) { r.pushNodePathElement(fieldPath) - ref := r.storage.AppendErrorWithMessage(message, r.path) - r.storage.Nodes[r.errorsRoot].ArrayValues = append(r.storage.Nodes[r.errorsRoot].ArrayValues, ref) + fastjsonext.AppendErrorToArray(r.errors, message, r.path) r.popNodePathElement(fieldPath) } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 34adaa283..dd4a690ed 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -211,7 +211,7 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons fetchTree = response.Data } - buf := r.getBuffer(t.resolvable.storage.Size()) + buf := r.getBuffer(t.resolvable.MaxSize()) defer r.releaseBuffer(buf) err = t.resolvable.Resolve(ctx.ctx, response.Data, fetchTree, buf) r.putTools(t) diff --git a/v2/pkg/engine/resolve/resolve_federation_test.go b/v2/pkg/engine/resolve/resolve_federation_test.go index c3156479d..bb718f30a 100644 --- a/v2/pkg/engine/resolve/resolve_federation_test.go +++ b/v2/pkg/engine/resolve/resolve_federation_test.go @@ -109,7 +109,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { ), Input: `{"method":"POST","url":"http://account.service","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations){__typename ... on Account {name shippingInfo {zip}}}}","variables":{"representations":$$0$$}}}`, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, InputTemplate: InputTemplate{ @@ -280,7 +280,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { SetTemplateOutputToNullOnVariableNull: true, DataSource: secondService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, DataSourceIdentifier: []byte("graphql_datasource.Source"), @@ -322,7 +322,7 @@ func TestResolveGraphQLResponse_Federation(t *testing.T) { SetTemplateOutputToNullOnVariableNull: true, DataSource: thirdService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, DataSourceIdentifier: []byte("graphql_datasource.Source"), diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 79aabad74..16d0072f8 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -1438,7 +1438,7 @@ func TestResolver_ResolveNode(t *testing.T) { }, }, }, - }, Context{ctx: context.Background()}, `{"data":{"id":1}}` + }, Context{ctx: context.Background()}, `{"data":{"id":"1"}}` })) t.Run("custom nullable", testGraphQLErrFn(func(t *testing.T, r *Resolver, ctrl *gomock.Controller) (node *Object, ctx Context, expectedErr string) { return &Object{ @@ -2012,7 +2012,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph 'Users' at Path 'query'."},{"message":"Cannot return null for non-nullable field 'Query.name'.","path":["name"]}],"data":null}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph 'Users' at Path 'query'."}],"data":null}` })) t.Run("root field with nested non-nullable fields returns null", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2475,7 +2475,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { }, }, }, - }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph at Path 'query', Reason: no data or errors in response."},{"message":"Cannot return null for non-nullable field 'Query.nonNullArray'.","path":["nonNullArray"]}],"data":null}` + }, Context{ctx: context.Background()}, `{"errors":[{"message":"Failed to fetch from Subgraph at Path 'query', Reason: no data or errors in response."}],"data":null}` })) t.Run("when data null and errors present not nullable array should result to null data upstream error and resolve error", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { return &GraphQLResponse{ @@ -2868,7 +2868,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -2911,7 +2911,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: productService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, InputTemplate: InputTemplate{ @@ -3070,7 +3070,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -3271,7 +3271,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -3481,7 +3481,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -3691,7 +3691,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -3881,7 +3881,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: reviewsService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -4138,7 +4138,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: timeService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -4178,7 +4178,7 @@ func TestResolver_ResolveGraphQLResponse(t *testing.T) { FetchConfiguration: FetchConfiguration{ DataSource: employeeService, PostProcessing: PostProcessingConfiguration{ - SelectResponseDataPath: []string{"data", "_entities", "[0]"}, + SelectResponseDataPath: []string{"data", "_entities", "0"}, }, }, }, @@ -4482,7 +4482,7 @@ func TestResolver_ResolveGraphQLSubscription(t *testing.T) { recorder.AwaitMessages(t, 1, defaultTimeout) recorder.AwaitComplete(t, defaultTimeout) assert.Equal(t, 1, len(recorder.Messages())) - assert.Equal(t, `{"errors":[{"message":"Validation error occurred","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}},{"message":"Cannot return null for non-nullable field 'Subscription.counter'.","path":["counter"]}],"data":null}`, recorder.Messages()[0]) + assert.Equal(t, `{"errors":[{"message":"Validation error occurred","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`, recorder.Messages()[0]) }) t.Run("should return an error if the data source has not been defined", func(t *testing.T) { @@ -5309,7 +5309,7 @@ func Benchmark_ResolveGraphQLResponse(b *testing.B) { { Name: []byte("age"), Value: &Integer{ - Path: []string{"[0]", "age"}, + Path: []string{"0", "age"}, }, }, }, @@ -5322,7 +5322,7 @@ func Benchmark_ResolveGraphQLResponse(b *testing.B) { { Name: []byte("line1"), Value: &String{ - Path: []string{"[1]", "line1"}, + Path: []string{"1", "line1"}, }, }, }, diff --git a/v2/pkg/fastjsonext/fastjsonext.go b/v2/pkg/fastjsonext/fastjsonext.go new file mode 100644 index 000000000..078d0b5b6 --- /dev/null +++ b/v2/pkg/fastjsonext/fastjsonext.go @@ -0,0 +1,181 @@ +package fastjsonext + +import ( + "bytes" + "fmt" + "strconv" + + "github.com/valyala/fastjson" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafebytes" +) + +var ( + NullValue = fastjson.MustParse(`null`) +) + +func MergeValues(a, b *fastjson.Value) (*fastjson.Value, bool) { + if a == nil { + return b, true + } + if b == nil { + return a, false + } + if a.Type() != b.Type() { + return a, false + } + switch a.Type() { + case fastjson.TypeObject: + ao, _ := a.Object() + bo, _ := b.Object() + ao.Visit(func(key []byte, l *fastjson.Value) { + sKey := unsafebytes.BytesToString(key) + r := bo.Get(sKey) + if r == nil { + return + } + merged, changed := MergeValues(l, r) + if changed { + ao.Set(unsafebytes.BytesToString(key), merged) + } + }) + bo.Visit(func(key []byte, r *fastjson.Value) { + sKey := unsafebytes.BytesToString(key) + if ao.Get(sKey) != nil { + return + } + ao.Set(sKey, r) + }) + return a, false + case fastjson.TypeArray: + aa, _ := a.Array() + ba, _ := b.Array() + for i := 0; i < len(ba); i++ { + a.SetArrayItem(len(aa)+i, ba[i]) + } + return a, false + case fastjson.TypeFalse: + if b.Type() == fastjson.TypeTrue { + return b, true + } + return a, false + case fastjson.TypeTrue: + if b.Type() == fastjson.TypeFalse { + return b, true + } + return a, false + case fastjson.TypeNull: + if b.Type() != fastjson.TypeNull { + return b, true + } + return a, false + case fastjson.TypeNumber: + af, _ := a.Float64() + bf, _ := b.Float64() + if af != bf { + return b, true + } + return a, false + case fastjson.TypeString: + as, _ := a.StringBytes() + bs, _ := b.StringBytes() + if !bytes.Equal(as, bs) { + return b, true + } + return a, false + default: + return b, true + } +} + +func MergeValuesWithPath(a, b *fastjson.Value, path ...string) (*fastjson.Value, bool) { + if len(path) == 0 { + return MergeValues(a, b) + } + root := fastjson.MustParseBytes([]byte(`{}`)) + current := root + for i := 0; i < len(path)-1; i++ { + current.Set(path[i], fastjson.MustParseBytes([]byte(`{}`))) + current = current.Get(path[i]) + } + current.Set(path[len(path)-1], b) + return MergeValues(a, root) +} + +func AppendToArray(array, value *fastjson.Value) { + if array.Type() != fastjson.TypeArray { + return + } + items, _ := array.Array() + array.SetArrayItem(len(items), value) +} + +func AppendErrorToArray(v *fastjson.Value, msg string, path []PathElement) { + if v.Type() != fastjson.TypeArray { + return + } + errorObject := CreateErrorObjectWithPath(msg, path) + items, _ := v.Array() + v.SetArrayItem(len(items), errorObject) +} + +func SetValue(v *fastjson.Value, value *fastjson.Value, path ...string) { + for i := 0; i < len(path)-1; i++ { + parent := v + v = v.Get(path[i]) + if v == nil { + child := fastjson.MustParse(`{}`) + parent.Set(path[i], child) + v = child + } + } + v.Set(path[len(path)-1], value) +} + +func SetNull(v *fastjson.Value, path ...string) { + SetValue(v, fastjson.MustParse(`null`), path...) +} + +func ValueIsNonNull(v *fastjson.Value) bool { + if v == nil { + return false + } + if v.Type() == fastjson.TypeNull { + return false + } + return true +} + +func ValueIsNull(v *fastjson.Value) bool { + return !ValueIsNonNull(v) +} + +type PathElement struct { + Name string + Idx int +} + +func CreateErrorObjectWithPath(message string, path []PathElement) *fastjson.Value { + errorObject := fastjson.MustParse(fmt.Sprintf(`{"message":"%s"}`, message)) + if len(path) == 0 { + return errorObject + } + errorPath := fastjson.MustParse(`[]`) + for i := range path { + if path[i].Name != "" { + errorPath.SetArrayItem(i, fastjson.MustParse(fmt.Sprintf(`"%s"`, path[i].Name))) + } else { + errorPath.SetArrayItem(i, fastjson.MustParse(strconv.FormatInt(int64(path[i].Idx), 10))) + } + } + errorObject.Set("path", errorPath) + return errorObject +} + +func PrintGraphQLResponse(data, errors *fastjson.Value) string { + out := fastjson.MustParse(`{}`) + if ValueIsNonNull(errors) { + out.Set("errors", errors) + } + out.Set("data", data) + return string(out.MarshalTo(nil)) +} diff --git a/v2/pkg/fastjsonext/fastjsonext_test.go b/v2/pkg/fastjsonext/fastjsonext_test.go new file mode 100644 index 000000000..15817bdc2 --- /dev/null +++ b/v2/pkg/fastjsonext/fastjsonext_test.go @@ -0,0 +1,116 @@ +package fastjsonext + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/valyala/fastjson" +) + +func TestMergeValues(t *testing.T) { + a, b := fastjson.MustParse(`{"a":1}`), fastjson.MustParse(`{"b":2}`) + merged, changed := MergeValues(a, b) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":1,"b":2}`, string(out)) + out = merged.Get("b").MarshalTo(out[:0]) + require.Equal(t, `2`, string(out)) +} + +func TestMergeValuesArray(t *testing.T) { + a, b := fastjson.MustParse(`[1,2]`), fastjson.MustParse(`[3,4]`) + merged, changed := MergeValues(a, b) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `[1,2,3,4]`, string(out)) +} + +func TestMergeValuesNestedObjects(t *testing.T) { + a, b := fastjson.MustParse(`{"a":{"b":1}}`), fastjson.MustParse(`{"a":{"c":2}}`) + merged, changed := MergeValues(a, b) + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":{"b":1,"c":2}}`, string(out)) +} + +func TestMergeValuesWithPath(t *testing.T) { + a, b := fastjson.MustParse(`{"a":{"b":1}}`), fastjson.MustParse(`{"c":2}`) + merged, changed := MergeValuesWithPath(a, b, "a") + require.Equal(t, false, changed) + out := merged.MarshalTo(nil) + require.Equal(t, `{"a":{"b":1,"c":2}}`, string(out)) + e := fastjson.MustParse(`{"e":true}`) + merged, changed = MergeValuesWithPath(merged, e, "a", "d") + require.Equal(t, false, changed) + out = merged.MarshalTo(out[:0]) + require.Equal(t, `{"a":{"b":1,"c":2,"d":{"e":true}}}`, string(out)) +} + +func TestGetArray(t *testing.T) { + a := fastjson.MustParse(`[{"name":"Jens"},{"name":"Jannik"}]`) + arr, err := a.Array() + require.NoError(t, err) + require.Equal(t, 2, len(arr)) + jens := arr[0].MarshalTo(nil) + require.Equal(t, `{"name":"Jens"}`, string(jens)) + jannik := arr[1].MarshalTo(nil) + require.Equal(t, `{"name":"Jannik"}`, string(jannik)) +} + +func TestSetNull(t *testing.T) { + a := fastjson.MustParse(`{"name":"Jens"}`) + SetNull(a, "name") + out := a.MarshalTo(nil) + require.Equal(t, `{"name":null}`, string(out)) + + b := fastjson.MustParse(`{"person":{"name":"Jens"}}`) + SetNull(b, "person", "name") + out = b.MarshalTo(nil) + require.Equal(t, `{"person":{"name":null}}`, string(out)) +} + +func TestSetWithNonExistingPath(t *testing.T) { + a := fastjson.MustParse(`{}`) + SetValue(a, fastjson.MustParse(`1`), "a", "b") + out := a.MarshalTo(nil) + require.Equal(t, `{"a":{"b":1}}`, string(out)) +} + +func TestAppendErrorWithMessage(t *testing.T) { + a := fastjson.MustParse(`[]`) + AppendErrorToArray(a, "error", nil) + out := a.MarshalTo(nil) + require.Equal(t, `[{"message":"error"}]`, string(out)) + + AppendErrorToArray(a, "error2", []PathElement{{Name: "a"}}) + out = a.MarshalTo(nil) + require.Equal(t, `[{"message":"error"},{"message":"error2","path":["a"]}]`, string(out)) +} + +func TestCreateErrorObjectWithPath(t *testing.T) { + v := CreateErrorObjectWithPath("my error message", []PathElement{ + {Name: "a"}, + }) + out := v.MarshalTo(nil) + require.Equal(t, `{"message":"my error message","path":["a"]}`, string(out)) + v = CreateErrorObjectWithPath("my error message", []PathElement{ + {Name: "a"}, + {Idx: 1}, + {Name: "b"}, + }) + out = v.MarshalTo(nil) + require.Equal(t, `{"message":"my error message","path":["a",1,"b"]}`, string(out)) + v = CreateErrorObjectWithPath("my error message", []PathElement{ + {Name: "a"}, + {Name: "b"}, + }) + out = v.MarshalTo(nil) + require.Equal(t, `{"message":"my error message","path":["a","b"]}`, string(out)) +} + +func TestAppendToArray(t *testing.T) { + a := fastjson.MustParse(`[1,2]`) + AppendToArray(a, fastjson.MustParse(`3`)) + out := a.MarshalTo(nil) + require.Equal(t, `[1,2,3]`, string(out)) +}