Skip to content

Commit

Permalink
Remove prepared statements cache, refactor various things
Browse files Browse the repository at this point in the history
  • Loading branch information
jlelse committed Jul 16, 2024
1 parent c0a2254 commit 72e09b7
Show file tree
Hide file tree
Showing 11 changed files with 338 additions and 300 deletions.
210 changes: 106 additions & 104 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,111 +8,106 @@ import (
"time"

"github.com/dgraph-io/ristretto"
"go.goblog.app/app/pkgs/bodylimit"
"go.goblog.app/app/pkgs/bufferpool"
"golang.org/x/sync/singleflight"
)

const (
cacheLoggedInKey contextKey = "cacheLoggedIn"
cacheExpirationKey contextKey = "cacheExpiration"

cacheControl = "Cache-Control"
cacheControl = "Cache-Control"
)

type cache struct {
g singleflight.Group
c *ristretto.Cache
}

func (a *goBlog) initCache() (err error) {
func (a *goBlog) initCache() error {
a.cache = &cache{}
if a.cfg.Cache != nil && !a.cfg.Cache.Enable {
// Cache disabled
return nil
return nil // Cache disabled
}
a.cache.c, err = ristretto.NewCache(&ristretto.Config{
NumCounters: 40 * 1000, // 4000 items when full with 5 KB items -> x10 = 40.000
MaxCost: 20 * 1000 * 1000, // 20 MB
BufferItems: 64, // recommended

c, err := ristretto.NewCache(&ristretto.Config{
NumCounters: 40000,
MaxCost: 20 * bodylimit.MB,
BufferItems: 64,
Metrics: true,
})
go func() {
ticker := time.NewTicker(15 * time.Minute)
for range ticker.C {
met := a.cache.c.Metrics
a.info("Cache metrics", "metrics", met.String())
}
}()
return
if err != nil {
return err
}

a.cache.c = c
go a.logCacheMetrics()
return nil
}

func (a *goBlog) logCacheMetrics() {
ticker := time.NewTicker(15 * time.Minute)
defer ticker.Stop()

for range ticker.C {
met := a.cache.c.Metrics
a.info("Cache metrics", "metrics", met.String())
}
}

func cacheLoggedIn(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), cacheLoggedInKey, true)))
ctx := context.WithValue(r.Context(), cacheLoggedInKey, true)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

func (a *goBlog) cacheMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Do checks
if a.cache.c == nil || !cacheable(r) {
next.ServeHTTP(w, r)
return
}
// Check login
if cli, ok := r.Context().Value(cacheLoggedInKey).(bool); ok && cli {
// Continue caching, but remove login
setLoggedIn(r, false)
} else if a.isLoggedIn(r) {
// Don't cache logged in requests
if a.cache.c == nil || !isCacheable(r) || a.shouldSkipLoggedIn(r) {
next.ServeHTTP(w, r)
return
}
// Search and serve cache
key := cacheKey(r)
// Get cache or render it
cacheInterface, _, _ := a.cache.g.Do(key, func() (any, error) {
return a.cache.getCache(key, next, r), nil

key := generateCacheKey(r)
cacheInterface, _, _ := a.cache.g.Do(key, func() (interface{}, error) {
return a.cache.getOrCreateCache(key, next, r), nil
})

ci := cacheInterface.(*cacheItem)
// copy and set headers
a.setCacheHeaders(w, ci)
// check conditional request
if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag {
// send 304
w.WriteHeader(http.StatusNotModified)
return
}
// set status code
w.WriteHeader(ci.code)
// write cached body
_, _ = w.Write(ci.body)
a.serveCachedResponse(w, r, ci)
})
}

func cacheable(r *http.Request) bool {
func isCacheable(r *http.Request) bool {
if r.Method != http.MethodGet && r.Method != http.MethodHead {
return false
}
if r.URL.Query().Get("cache") == "0" || r.URL.Query().Get("cache") == "false" {
return r.URL.Query().Get("cache") != "0" && r.URL.Query().Get("cache") != "false"
}

func (a *goBlog) shouldSkipLoggedIn(r *http.Request) bool {
if cli, ok := r.Context().Value(cacheLoggedInKey).(bool); ok && cli {
setLoggedIn(r, false)
return false
}
return true
return a.isLoggedIn(r)
}

func cacheKey(r *http.Request) (key string) {
func generateCacheKey(r *http.Request) string {
buf := bufferpool.Get()
defer bufferpool.Put(buf)
// Special cases
if asRequest, ok := r.Context().Value(asRequestKey).(bool); ok && asRequest {
_, _ = buf.WriteString("as-")
buf.WriteString("as-")
}
if torUsed, ok := r.Context().Value(torUsedKey).(bool); ok && torUsed {
_, _ = buf.WriteString("tor-")
buf.WriteString("tor-")
}
// Add cache URL
_, _ = buf.WriteString(r.URL.EscapedPath())
buf.WriteString(r.URL.EscapedPath())
if query := r.URL.Query(); len(query) > 0 {
_ = buf.WriteByte('?')
buf.WriteByte('?')
keys := make([]string, 0, len(query))
for k := range query {
keys = append(keys, k)
Expand All @@ -130,11 +125,20 @@ func cacheKey(r *http.Request) (key string) {
}
}
}
// Get key as string
key = buf.String()
// Return buffer to pool
bufferpool.Put(buf)
return

return buf.String()
}

func (a *goBlog) serveCachedResponse(w http.ResponseWriter, r *http.Request, ci *cacheItem) {
a.setCacheHeaders(w, ci)

if ifNoneMatchHeader := r.Header.Get("If-None-Match"); ifNoneMatchHeader != "" && ifNoneMatchHeader == ci.eTag {
w.WriteHeader(http.StatusNotModified)
return
}

w.WriteHeader(ci.code)
_, _ = w.Write(ci.body)
}

func (a *goBlog) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) {
Expand All @@ -147,62 +151,60 @@ func (a *goBlog) setCacheHeaders(w http.ResponseWriter, cache *cacheItem) {
w.Header().Set(cacheControl, "public,no-cache")
}

type cacheItem struct {
expiration int
eTag string
code int
header http.Header
body []byte
}

// Calculate byte size of cache item using size of header, body and etag
func (ci *cacheItem) cost() int {
headerBuf := bufferpool.Get()
_ = ci.header.Write(headerBuf)
headerSize := len(headerBuf.Bytes())
bufferpool.Put(headerBuf)
return headerSize + len(ci.body) + len(ci.eTag)
}

func (c *cache) getCache(key string, next http.Handler, r *http.Request) *cacheItem {
func (c *cache) getOrCreateCache(key string, next http.Handler, r *http.Request) *cacheItem {
if rItem, ok := c.c.Get(key); ok {
return rItem.(*cacheItem)
}
// No cache available
// Make and use copy of r
//nolint:contextcheck
cr := r.Clone(valueOnlyContext{r.Context()})
// Remove problematic headers
cr.Header.Del("If-Modified-Since")
cr.Header.Del("If-Unmodified-Since")
cr.Header.Del("If-None-Match")
cr.Header.Del("If-Match")
cr.Header.Del("If-Range")
cr.Header.Del("Range")
// Record request

// Remove original timeout, add new one
withoutCancelCtx := context.WithoutCancel(r.Context())
newCancelCtx, cancel := context.WithTimeout(withoutCancelCtx, 5*time.Minute)
defer cancel()

cr := r.Clone(newCancelCtx)
removeConditionalHeaders(cr)

rec := newCacheRecorder()
next.ServeHTTP(rec, cr)
item := rec.finish()
// Set expiration

item.expiration, _ = cr.Context().Value(cacheExpirationKey).(int)
// Remove problematic headers
item.header.Del("Accept-Ranges")
item.header.Del("ETag")
item.header.Del("Last-Modified")
// Save cache
if cch := item.header.Get(cacheControl); !containsStrings(cch, "no-store", "private", "no-cache") {
cost := int64(item.cost())
if item.expiration == 0 {
// Cache items max. 6 hours
c.c.SetWithTTL(key, item, cost, 6*time.Hour)
} else {
c.c.SetWithTTL(key, item, cost, time.Duration(item.expiration)*time.Second)
}
c.c.Wait()
removeProblematicHeaders(item.header)

if shouldCacheItem(item.header.Get(cacheControl)) {
c.saveCache(key, item)
}

return item
}

func removeConditionalHeaders(r *http.Request) {
headers := []string{"If-Modified-Since", "If-Unmodified-Since", "If-None-Match", "If-Match", "If-Range", "Range"}
for _, h := range headers {
r.Header.Del(h)
}
}

func removeProblematicHeaders(header http.Header) {
headers := []string{"Accept-Ranges", "ETag", "Last-Modified"}
for _, h := range headers {
header.Del(h)
}
}

func shouldCacheItem(cacheControlHeader string) bool {
return !containsStrings(cacheControlHeader, "no-store", "private", "no-cache")
}

func (c *cache) saveCache(key string, item *cacheItem) {
ttl := 6 * time.Hour
if item.expiration > 0 {
ttl = time.Duration(item.expiration) * time.Second
}
c.c.SetWithTTL(key, item, item.cost(), ttl)
c.c.Wait()
}

func (c *cache) purge() {
if c == nil {
return
Expand Down
54 changes: 48 additions & 6 deletions cacheRecorder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,42 @@ import (
"crypto/sha256"
"fmt"
"net/http"
"sync"
"unsafe"
)

// cacheRecorder is an implementation of http.ResponseWriter
// cacheRecorder is a thread-safe implementation of http.ResponseWriter
type cacheRecorder struct {
mu sync.Mutex
item cacheItem
done bool
}

type cacheItem struct {
expiration int
eTag string
code int
header http.Header
body []byte
}

func newCacheRecorder() *cacheRecorder {
return &cacheRecorder{
item: cacheItem{
code: http.StatusOK,
header: http.Header{},
header: make(http.Header),
},
}
}

func (c *cacheRecorder) finish() *cacheItem {
c.mu.Lock()
defer c.mu.Unlock()

if c.done {
return &c.item
}

c.done = true
c.item.eTag = c.item.header.Get("ETag")
if c.item.eTag == "" {
Expand All @@ -32,6 +50,9 @@ func (c *cacheRecorder) finish() *cacheItem {

// Header implements http.ResponseWriter.
func (c *cacheRecorder) Header() http.Header {
c.mu.Lock()
defer c.mu.Unlock()

if c.done {
return nil
}
Expand All @@ -40,25 +61,46 @@ func (c *cacheRecorder) Header() http.Header {

// Write implements http.ResponseWriter.
func (c *cacheRecorder) Write(buf []byte) (int, error) {
c.mu.Lock()
defer c.mu.Unlock()

if c.done {
return 0, nil
return 0, fmt.Errorf("write after finish")
}
c.item.body = append(c.item.body, buf...)
return len(buf), nil
}

// WriteString implements io.StringWriter.
func (c *cacheRecorder) WriteString(str string) (int, error) {
if c.done {
return 0, nil
}
return c.Write([]byte(str))
}

// WriteHeader implements http.ResponseWriter.
func (c *cacheRecorder) WriteHeader(code int) {
c.mu.Lock()
defer c.mu.Unlock()

if c.done {
return
}
c.item.code = code
}

func (ci *cacheItem) cost() int64 {
size := int64(unsafe.Sizeof(*ci)) // Base struct size

// Add sizes of variable-length fields
size += int64(len(ci.eTag))
size += int64(len(ci.body))

// Calculate header size
for key, values := range ci.header {
size += int64(len(key))
for _, value := range values {
size += int64(len(value))
}
}

return size
}
Loading

0 comments on commit 72e09b7

Please sign in to comment.