Skip to content

Commit

Permalink
feat: lex-sort query args and arg transformer hook
Browse files Browse the repository at this point in the history
This commit introduces a new feature that enables the specification
of specific query parameters to be used as the cache index key
when creating a URI via a arg transformer hook.

For instance, it allows caching requests like GET /foo?mode=bar. The
assumption is that users are aware of the query parameters that can
be sent and that these parameters are utilized by the handler to modify
the request. Consequently, even if the user includes additional
irrelevant parameters like GET /foo?mode=bar&junk=xyz, the cache will
still be utilized since the specified parameters are the key.

Also, we now lexicographically sort the query params so that cache
cannot be busted by reordering the params.
  • Loading branch information
rhnvrm committed Mar 6, 2024
1 parent 1d2547a commit 803d284
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 29 deletions.
59 changes: 43 additions & 16 deletions fastcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Options struct {

// Cache based on uri+querystring.
IncludeQueryString bool

QueryArgsTransformerHook func(*fasthttp.Args)
}

// Item represents the cache entry for a single endpoint with the actual cache
Expand Down Expand Up @@ -94,14 +96,8 @@ func (f *FastCache) Cached(h fastglue.FastRequestHandler, o *Options, group stri
}
return h(r)
}
var hash [16]byte
// If IncludeQueryString option is set then cache based on uri + md5(query_string)
if o.IncludeQueryString {
hash = md5.Sum(r.RequestCtx.URI().FullURI())
} else {
hash = md5.Sum(r.RequestCtx.URI().Path())
}
uri := hex.EncodeToString(hash[:])

uri := f.makeURI(r, o)

// Fetch etag + cached bytes from the store.
blob, err := f.s.Get(namespace, group, uri)
Expand Down Expand Up @@ -193,6 +189,44 @@ func (f *FastCache) DelGroup(namespace string, group ...string) error {
return f.s.DelGroup(namespace, group...)
}

func (f *FastCache) makeURI(r *fastglue.Request, o *Options) string {
var hash [16]byte

// lexicographically sort the query string.
r.RequestCtx.QueryArgs().Sort(func(x, y []byte) int {
return bytes.Compare(x, y)
})

// If IncludeQueryString option is set then cache based on uri + md5(query_string)
if o.IncludeQueryString {
id := r.RequestCtx.URI().FullURI()

// Check if we need to include only specific query params.
if o.QueryArgsTransformerHook != nil {
// Acquire a copy so as to not modify the request.
uriRaw := fasthttp.AcquireURI()
r.RequestCtx.URI().CopyTo(uriRaw)

q := uriRaw.QueryArgs()

// Call the hook to transform the query string.
o.QueryArgsTransformerHook(q)

// Get the new URI.
id = uriRaw.FullURI()

// Release the borrowed URI.
fasthttp.ReleaseURI(uriRaw)
}

hash = md5.Sum(id)
} else {
hash = md5.Sum(r.RequestCtx.URI().Path())
}

return hex.EncodeToString(hash[:])
}

// cache caches a response body.
func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Options) error {
// ETag?.
Expand All @@ -206,14 +240,7 @@ func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Optio
}

// Write cache to the store (etag, content type, response body).
var hash [16]byte
// If IncludeQueryString option is set then cache based on uri + md5(query_string)
if o.IncludeQueryString {
hash = md5.Sum(r.RequestCtx.URI().FullURI())
} else {
hash = md5.Sum(r.RequestCtx.URI().Path())
}
uri := hex.EncodeToString(hash[:])
uri := f.makeURI(r, o)

var blob []byte
if !o.NoBlob {
Expand Down
164 changes: 159 additions & 5 deletions fastcache_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package fastcache_test

import (
"io/ioutil"
"fmt"
"io"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -50,6 +51,49 @@ func init() {
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
}

includeQS = &fastcache.Options{
NamespaceKey: namespaceKey,
ETag: true,
TTL: time.Second * 5,
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
IncludeQueryString: true,
}

includeQSNoEtag = &fastcache.Options{
NamespaceKey: namespaceKey,
ETag: false,
TTL: time.Second * 5,
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
IncludeQueryString: true,
}

includeQSSpecific = &fastcache.Options{
NamespaceKey: namespaceKey,
ETag: true,
TTL: time.Second * 5,
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
IncludeQueryString: true,
QueryArgsTransformerHook: func(args *fasthttp.Args) {
// Copy the keys to delete, and delete them later. This is to
// avoid borking the VisitAll() iterator.
mp := map[string]struct{}{
"foo": {},
}

delKeys := [][]byte{}
args.VisitAll(func(k, v []byte) {
if _, ok := mp[string(k)]; !ok {
delKeys = append(delKeys, k)
}
})

// Delete the keys.
for _, k := range delKeys {
args.DelBytes(k)
}
},
}

fc = fastcache.New(cachestore.New("CACHE:", redis.NewClient(&redis.Options{
Addr: rd.Addr(),
})))
Expand Down Expand Up @@ -78,6 +122,19 @@ func init() {
return r.SendBytes(200, "text/plain", []byte("ok"))
}, ttlShort, group))

srv.GET("/include-qs", fc.Cached(func(r *fastglue.Request) error {
return r.SendBytes(200, "text/plain", []byte("ok"))
}, includeQS, group))

srv.GET("/include-qs-no-etag", fc.Cached(func(r *fastglue.Request) error {
out := time.Now()
return r.SendBytes(200, "text/plain", []byte(fmt.Sprintf("%v", out)))
}, includeQSNoEtag, group))

srv.GET("/include-qs-specific", fc.Cached(func(r *fastglue.Request) error {
return r.SendBytes(200, "text/plain", []byte("ok"))
}, includeQSSpecific, group))

// Start the server
go func() {
s := &fasthttp.Server{
Expand Down Expand Up @@ -111,7 +168,7 @@ func getReq(url, etag string, t *testing.T) (*http.Response, string) {
t.Fatal(err)
}

b, err := ioutil.ReadAll(resp.Body)
b, err := io.ReadAll(resp.Body)

Check failure on line 171 in fastcache_test.go

View workflow job for this annotation

GitHub Actions / Go 1.13 Tests

undefined: io.ReadAll

Check failure on line 171 in fastcache_test.go

View workflow job for this annotation

GitHub Actions / Go 1.14 Tests

undefined: io.ReadAll

Check failure on line 171 in fastcache_test.go

View workflow job for this annotation

GitHub Actions / Go 1.15 Tests

undefined: io.ReadAll

Check failure on line 171 in fastcache_test.go

View workflow job for this annotation

GitHub Actions / Go 1.13 Tests

undefined: io.ReadAll

Check failure on line 171 in fastcache_test.go

View workflow job for this annotation

GitHub Actions / Go 1.14 Tests

undefined: io.ReadAll

Check failure on line 171 in fastcache_test.go

View workflow job for this annotation

GitHub Actions / Go 1.15 Tests

undefined: io.ReadAll
if err != nil {
t.Fatal(b)
}
Expand Down Expand Up @@ -139,22 +196,119 @@ func TestCache(t *testing.T) {
}

// Wrong etag.
r, b = getReq(srvRoot+"/cached", "wrong", t)
r, _ = getReq(srvRoot+"/cached", "wrong", t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got '%v'", r.StatusCode)
}

// Clear cache.
r, b = getReq(srvRoot+"/clear-group", "", t)
r, _ = getReq(srvRoot+"/clear-group", "", t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got %v", r.StatusCode)
}
r, b = getReq(srvRoot+"/cached", r.Header.Get("Etag"), t)
r, _ = getReq(srvRoot+"/cached", r.Header.Get("Etag"), t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got '%v'", r.StatusCode)
}
}

func TestQueryString(t *testing.T) {
// First request should be 200.
r, b := getReq(srvRoot+"/include-qs?foo=bar", "", t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got %v", r.StatusCode)
}

if b != "ok" {
t.Fatalf("expected 'ok' in body but got %v", b)
}

// Second should be 304.
r, _ = getReq(srvRoot+"/include-qs?foo=bar", r.Header.Get("Etag"), t)
if r.StatusCode != 304 {
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
}
}

func TestQueryStringLexicographical(t *testing.T) {
// First request should be 200.
r, b := getReq(srvRoot+"/include-qs?foo=bar&baz=qux", "", t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got %v", r.StatusCode)
}

if b != "ok" {
t.Fatalf("expected 'ok' in body but got %v", b)
}

// Second should be 304.
r, _ = getReq(srvRoot+"/include-qs?baz=qux&foo=bar", r.Header.Get("Etag"), t)
if r.StatusCode != 304 {
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
}
}

func TestQueryStringWithoutEtag(t *testing.T) {
// First request should be 200.
r, b := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got %v", r.StatusCode)
}

// Second should be 200 but with same response.
r2, b2 := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", t)
if r2.StatusCode != 200 {
t.Fatalf("expected 200 but got '%v'", r2.StatusCode)
}

if b2 != b {
t.Fatalf("expected '%v' in body but got %v", b, b2)
}

// Third should be 200 but with different response.
r3, b3 := getReq(srvRoot+"/include-qs-no-etag?foo=baz", "", t)
if r3.StatusCode != 200 {
t.Fatalf("expected 200 but got '%v'", r3.StatusCode)
}

// time should be different
if b3 == b {
t.Fatalf("expected both to be different (should not be %v), but got %v", b, b3)
}
}

func TestQueryStringSpecific(t *testing.T) {
// First request should be 200.
r1, b := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", "", t)
if r1.StatusCode != 200 {
t.Fatalf("expected 200 but got %v", r1.StatusCode)
}
if b != "ok" {
t.Fatalf("expected 'ok' in body but got %v", b)
}

// Second should be 304.
r, _ := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", r1.Header.Get("Etag"), t)
if r.StatusCode != 304 {
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
}

// Third should be 304 as foo=bar
r, _ = getReq(srvRoot+"/include-qs-specific?loo=mar&foo=bar&baz=qux&quux=quuz", r1.Header.Get("Etag"), t)
if r.StatusCode != 304 {
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
}

// Fourth should be 200 as foo=rab
r, b = getReq(srvRoot+"/include-qs-specific?foo=rab&baz=qux&quux=quuz", r1.Header.Get("Etag"), t)
if r.StatusCode != 200 {
t.Fatalf("expected 200 but got '%v'", r.StatusCode)
}
if b != "ok" {
t.Fatalf("expected 'ok' in body but got %v", b)
}
}

func TestNoCache(t *testing.T) {
// All requests should return 200.
for n := 0; n < 3; n++ {
Expand Down
19 changes: 11 additions & 8 deletions stores/goredis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
// The internal structure looks like this where
// XX1234 = namespace, marketwach = group
// ```
// CACHE:XX1234:marketwatch {
// "/user/marketwatch_ctype" -> []byte
// "/user/marketwatch_etag" -> []byte
// "/user/marketwatch_blob" -> []byte
// "/user/marketwatch/123_ctype" -> []byte
// "/user/marketwatch/123_etag" -> []byte
// "/user/marketwatch/123_blob" -> []byte
// }
//
// CACHE:XX1234:marketwatch {
// "/user/marketwatch_ctype" -> []byte
// "/user/marketwatch_etag" -> []byte
// "/user/marketwatch_blob" -> []byte
// "/user/marketwatch/123_ctype" -> []byte
// "/user/marketwatch/123_etag" -> []byte
// "/user/marketwatch/123_blob" -> []byte
// }
//
// ```
package goredis

Expand Down Expand Up @@ -53,6 +55,7 @@ func (s *Store) Get(namespace, group, uri string) (fastcache.Item, error) {
var (
out fastcache.Item
)

// Get content_type, etag, blob in that order.
cmd := s.cn.HMGet(s.ctx, s.key(namespace, group), s.field(keyCtype, uri), s.field(keyEtag, uri), s.field(keyBlob, uri))
if err := cmd.Err(); err != nil {
Expand Down

0 comments on commit 803d284

Please sign in to comment.