Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to support raw json fields #41

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ install:
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d -s .)
- go tool vet .
- go vet .
- go test -v -race ./...
50 changes: 50 additions & 0 deletions graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package graphql_test

import (
"context"
"encoding/json"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -64,6 +65,55 @@ func TestClient_Query_partialDataWithErrorResponse(t *testing.T) {
}
}

func TestClient_Query_partialDataRawQueryWithErrorResponse(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/graphql", func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Content-Type", "application/json")
mustWrite(w, `{
"data": {
"node1": { "id": "MDEyOklzc3VlQ29tbWVudDE2OTQwNzk0Ng==" },
"node2": null
},
"errors": [
{
"message": "Could not resolve to a node with the global id of 'NotExist'",
"type": "NOT_FOUND",
"path": [
"node2"
],
"locations": [
{
"line": 10,
"column": 4
}
]
}
]
}`)
})
client := graphql.NewClient("/graphql", &http.Client{Transport: localRoundTripper{handler: mux}})

var q struct {
Node1 json.RawMessage `graphql:"node1"`
Node2 *struct {
ID graphql.ID
} `graphql:"node2: node(id: \"NotExist\")"`
}
err := client.Query(context.Background(), &q, nil)
if err == nil {
t.Fatal("got error: nil, want: non-nil\n")
}
if got, want := err.Error(), "Could not resolve to a node with the global id of 'NotExist'"; got != want {
t.Errorf("got error: %v, want: %v\n", got, want)
}
if q.Node1 == nil || string(q.Node1) != `{"id":"MDEyOklzc3VlQ29tbWVudDE2OTQwNzk0Ng=="}` {
t.Errorf("got wrong q.Node1: %v\n", string(q.Node1))
}
if q.Node2 != nil {
t.Errorf("got non-nil q.Node2: %v, want: nil\n", *q.Node2)
}
}

func TestClient_Query_noDataWithErrorResponse(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/graphql", func(w http.ResponseWriter, req *http.Request) {
Expand Down
39 changes: 29 additions & 10 deletions internal/jsonutil/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func UnmarshalGraphQL(data []byte, v interface{}) error {
type decoder struct {
tokenizer interface {
Token() (json.Token, error)
Decode(v interface{}) error
}

// Stack of what part of input JSON we're in the middle of - objects, arrays.
Expand All @@ -68,10 +69,14 @@ func (d *decoder) Decode(v interface{}) error {

// decode decodes a single JSON value from d.tokenizer into d.vs.
func (d *decoder) decode() error {
rawMessageValue := reflect.ValueOf(json.RawMessage{})

// The loop invariant is that the top of each d.vs stack
// is where we try to unmarshal the next JSON value we see.
for len(d.vs) > 0 {
var tok interface{}
tok, err := d.tokenizer.Token()

if err == io.EOF {
return errors.New("unexpected end of JSON input")
} else if err != nil {
Expand All @@ -87,6 +92,8 @@ func (d *decoder) decode() error {
return errors.New("unexpected non-key in JSON input")
}
someFieldExist := false
// If one field is raw all must be treated as raw
rawMessage := false
for i := range d.vs {
v := d.vs[i][len(d.vs[i])-1]
if v.Kind() == reflect.Ptr {
Expand All @@ -97,24 +104,36 @@ func (d *decoder) decode() error {
f = fieldByGraphQLName(v, key)
if f.IsValid() {
someFieldExist = true
// Check for special embedded json
if f.Type() == rawMessageValue.Type() {
rawMessage = true
}
}

}
d.vs[i] = append(d.vs[i], f)
}
if !someFieldExist {
return fmt.Errorf("struct field for %q doesn't exist in any of %v places to unmarshal", key, len(d.vs))
}

// We've just consumed the current token, which was the key.
// Read the next token, which should be the value, and let the rest of code process it.
tok, err = d.tokenizer.Token()
if err == io.EOF {
return errors.New("unexpected end of JSON input")
} else if err != nil {
return err
if rawMessage {
// Read the next complete object from the json stream
var data json.RawMessage
d.tokenizer.Decode(&data)
tok = data
} else {
// We've just consumed the current token, which was the key.
// Read the next token, which should be the value, and let the rest of code process it.
tok, err = d.tokenizer.Token()
if err == io.EOF {
return errors.New("unexpected end of JSON input")
} else if err != nil {
return err
}
}

// Are we inside an array and seeing next value (rather than end of array)?
// Are we inside an array and seeing next value (rather than end of array)?
case d.state() == '[' && tok != json.Delim(']'):
someSliceExist := false
for i := range d.vs {
Expand All @@ -136,7 +155,7 @@ func (d *decoder) decode() error {
}

switch tok := tok.(type) {
case string, json.Number, bool, nil:
case string, json.Number, bool, nil, json.RawMessage:
// Value.

for i := range d.vs {
Expand Down Expand Up @@ -302,7 +321,7 @@ func isGraphQLFragment(f reflect.StructField) bool {
// unmarshalValue unmarshals JSON value into v.
// v must be addressable and not obtained by the use of unexported
// struct fields, otherwise unmarshalValue will panic.
func unmarshalValue(value json.Token, v reflect.Value) error {
func unmarshalValue(value interface{}, v reflect.Value) error {
b, err := json.Marshal(value) // TODO: Short-circuit (if profiling says it's worth it).
if err != nil {
return err
Expand Down
24 changes: 24 additions & 0 deletions internal/jsonutil/graphql_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jsonutil_test

import (
"encoding/json"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -80,6 +81,29 @@ func TestUnmarshalGraphQL_jsonTag(t *testing.T) {
}
}

func TestUnmarshalGraphQL_jsonRawTag(t *testing.T) {
type query struct {
Data json.RawMessage
Another string
}
var got query
err := jsonutil.UnmarshalGraphQL([]byte(`{
"Data": { "foo":"bar" },
"Another" : "stuff"
}`), &got)

if err != nil {
t.Fatal(err)
}
want := query{
Another: "stuff",
Data: []byte(`{"foo":"bar"}`),
}
if !reflect.DeepEqual(got, want) {
t.Errorf("not equal: %v %v", want, got)
}
}

func TestUnmarshalGraphQL_array(t *testing.T) {
type query struct {
Foo []graphql.String
Expand Down