diff --git a/docs/content/reference/apq.md b/docs/content/reference/apq.md new file mode 100644 index 00000000000..1c978db4a4b --- /dev/null +++ b/docs/content/reference/apq.md @@ -0,0 +1,77 @@ +--- +title: "Automatic persisted queries" +description: +linkTitle: "APQ" +menu: { main: { parent: 'reference' } } +--- + +When you work with GraphQL by default your queries are transferred with every request. That can waste significant +bandwidth. To avoid that you can use Automatic Persisted Queries (APQ). + +With APQ you send only query hash to the server. If hash is not found on a server then client makes a second request +to register query hash with original query on a server. + +## Usage + +In order to enable Automatic Persisted Queries you need to change your client. For more information see +[Automatic Persisted Queries Link](https://github.com/apollographql/apollo-link-persisted-queries) documentation. + +For the server you need to implement `PersistedQueryCache` interface and pass instance to +`handler.EnablePersistedQueryCache` option. + +See example using [go-redis](github.com/go-redis/redis) package below: +```go +import ( + "context" + "time" + + "github.com/go-redis/redis" + "github.com/pkg/errors" +) + +type Cache struct { + client redis.UniversalClient + ttl time.Duration +} + +const apqPrefix = "apq:" + +func NewCache(redisAddress string, password string, ttl time.Duration) (*Cache, error) { + client := redis.NewClient(&redis.Options{ + Addr: redisAddress, + }) + + err := client.Ping().Err() + if err != nil { + return nil, errors.WithStack(err) + } + + return &Cache{client: client, ttl: ttl}, nil +} + +func (c *Cache) Add(ctx context.Context, hash string, query string) { + c.client.Set(apqPrefix + hash, query, c.ttl) +} + +func (c *Cache) Get(ctx context.Context, hash string) (string, bool) { + s, err := c.client.Get(apqPrefix + hash).Result() + if err != nil { + return "", false + } + return s, true +} + +func main() { + cache, err := NewCache(cfg.RedisAddress, 24*time.Hour) + if err != nil { + log.Fatalf("cannot create APQ redis cache: %v", err) + } + + c := Config{ Resolvers: &resolvers{} } + gqlHandler := handler.GraphQL( + blog.NewExecutableSchema(c), + handler.EnablePersistedQueryCache(cache), + ) + http.Handle("/query", gqlHandler) +} +``` diff --git a/handler/graphql.go b/handler/graphql.go index 8c3882ce9c2..289901f0f0f 100644 --- a/handler/graphql.go +++ b/handler/graphql.go @@ -2,6 +2,8 @@ package handler import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -28,6 +30,26 @@ type params struct { Query string `json:"query"` OperationName string `json:"operationName"` Variables map[string]interface{} `json:"variables"` + Extensions *extensions `json:"extensions"` +} + +type extensions struct { + PersistedQuery *persistedQuery `json:"persistedQuery"` +} + +type persistedQuery struct { + Sha256 string `json:"sha256Hash"` + Version int64 `json:"version"` +} + +const ( + errPersistedQueryNotSupported = "PersistedQueryNotSupported" + errPersistedQueryNotFound = "PersistedQueryNotFound" +) + +type PersistedQueryCache interface { + Add(ctx context.Context, hash string, query string) + Get(ctx context.Context, hash string) (string, bool) } type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error @@ -47,6 +69,7 @@ type Config struct { connectionKeepAlivePingInterval time.Duration uploadMaxMemory int64 uploadMaxSize int64 + apqCache PersistedQueryCache } func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -296,6 +319,13 @@ func WebsocketKeepAliveDuration(duration time.Duration) Option { } } +// Add cache that will hold queries for automatic persisted queries (APQ) +func EnablePersistedQueryCache(cache PersistedQueryCache) Option { + return func(cfg *Config) { + cfg.apqCache = cache + } +} + const DefaultCacheSize = 1000 const DefaultConnectionKeepAlivePingInterval = 25 * time.Second @@ -355,6 +385,11 @@ type graphqlHandler struct { exec graphql.ExecutableSchema } +func computeQueryHash(query string) string { + b := sha256.Sum256([]byte(query)) + return hex.EncodeToString(b[:]) +} + func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") @@ -380,6 +415,13 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + + if extensions := r.URL.Query().Get("extensions"); extensions != "" { + if err := jsonDecode(strings.NewReader(extensions), &reqParams.Extensions); err != nil { + sendErrorf(w, http.StatusBadRequest, "extensions could not be decoded") + return + } + } case http.MethodPost: mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) if err != nil { @@ -420,6 +462,41 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + var queryHash string + apqRegister := false + apq := reqParams.Extensions != nil && reqParams.Extensions.PersistedQuery != nil + if apq { + // client has enabled apq + queryHash = reqParams.Extensions.PersistedQuery.Sha256 + if gh.cfg.apqCache == nil { + // server has disabled apq + sendErrorf(w, http.StatusOK, errPersistedQueryNotSupported) + return + } + if reqParams.Extensions.PersistedQuery.Version != 1 { + sendErrorf(w, http.StatusOK, "Unsupported persisted query version") + return + } + if reqParams.Query == "" { + // client sent optimistic query hash without query string + query, ok := gh.cfg.apqCache.Get(ctx, queryHash) + if !ok { + sendErrorf(w, http.StatusOK, errPersistedQueryNotFound) + return + } + reqParams.Query = query + } else { + if computeQueryHash(reqParams.Query) != queryHash { + sendErrorf(w, http.StatusOK, "provided sha does not match query") + return + } + apqRegister = true + } + } else if reqParams.Query == "" { + sendErrorf(w, http.StatusUnprocessableEntity, "Must provide query string") + return + } + var doc *ast.QueryDocument var cacheHit bool if gh.cache != nil { @@ -474,6 +551,11 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if apqRegister && gh.cfg.apqCache != nil { + // Add to persisted query cache + gh.cfg.apqCache.Add(ctx, queryHash, reqParams.Query) + } + switch op.Operation { case ast.Query: b, err := json.Marshal(gh.exec.Query(ctx, op)) diff --git a/handler/graphql_test.go b/handler/graphql_test.go index c97c5290ea9..bfcc11082c2 100644 --- a/handler/graphql_test.go +++ b/handler/graphql_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/99designs/gqlgen/graphql" + lru "github.com/hashicorp/golang-lru" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/vektah/gqlparser/ast" @@ -764,3 +765,68 @@ func TestBytesRead(t *testing.T) { require.Equal(t, "0193456789", string(got)) }) } + +type memoryPersistedQueryCache struct { + cache *lru.Cache +} + +func newMemoryPersistedQueryCache(size int) (*memoryPersistedQueryCache, error) { + cache, err := lru.New(size) + return &memoryPersistedQueryCache{cache: cache}, err +} + +func (c *memoryPersistedQueryCache) Add(ctx context.Context, hash string, query string) { + c.cache.Add(hash, query) +} + +func (c *memoryPersistedQueryCache) Get(ctx context.Context, hash string) (string, bool) { + val, ok := c.cache.Get(hash) + if !ok { + return "", ok + } + return val.(string), ok +} +func TestAutomaticPersistedQuery(t *testing.T) { + cache, err := newMemoryPersistedQueryCache(1000) + require.NoError(t, err) + h := GraphQL(&executableSchemaStub{}, EnablePersistedQueryCache(cache)) + t.Run("automatic persisted query POST", func(t *testing.T) { + // normal queries should be unaffected + resp := doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }"}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + + // first pass: optimistic hash without query string + resp = doRequest(h, "POST", "/graphql", `{"extensions":{"persistedQuery":{"sha256Hash":"b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88","version":1}}}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"errors":[{"message":"PersistedQueryNotFound"}],"data":null}`, resp.Body.String()) + // second pass: query with query string and query hash + resp = doRequest(h, "POST", "/graphql", `{"query":"{ me { name } }", "extensions":{"persistedQuery":{"sha256Hash":"b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88","version":1}}}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + // future requests without query string + resp = doRequest(h, "POST", "/graphql", `{"extensions":{"persistedQuery":{"sha256Hash":"b8d9506e34c83b0e53c2aa463624fcea354713bc38f95276e6f0bd893ffb5b88","version":1}}}`) + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) + + t.Run("automatic persisted query GET", func(t *testing.T) { + // normal queries should be unaffected + resp := doRequest(h, "GET", "/graphql?query={me{name}}", "") + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + + // first pass: optimistic hash without query string + resp = doRequest(h, "GET", `/graphql?extensions={"persistedQuery":{"version":1,"sha256Hash":"b58723c4fd7ce18043ae53635b304ba6cee765a67009645b04ca01e80ce1c065"}}`, "") + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"errors":[{"message":"PersistedQueryNotFound"}],"data":null}`, resp.Body.String()) + // second pass: query with query string and query hash + resp = doRequest(h, "GET", `/graphql?query={me{name}}&extensions={"persistedQuery":{"sha256Hash":"b58723c4fd7ce18043ae53635b304ba6cee765a67009645b04ca01e80ce1c065","version":1}}}`, "") + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + // future requests without query string + resp = doRequest(h, "GET", `/graphql?extensions={"persistedQuery":{"version":1,"sha256Hash":"b58723c4fd7ce18043ae53635b304ba6cee765a67009645b04ca01e80ce1c065"}}`, "") + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, `{"data":{"name":"test"}}`, resp.Body.String()) + }) +}