Skip to content

Commit

Permalink
Add cache middleware to zrouter and fix ttl on combinedCache (#63)
Browse files Browse the repository at this point in the history
* Add cache middleware to zrouter

* Add headers to response

* Some improvements

* minor fix
  • Loading branch information
lucaslopezf authored Feb 20, 2024
1 parent fcd31fe commit 0a3a76b
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 3 deletions.
5 changes: 2 additions & 3 deletions pkg/zcache/combined_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type combinedCache struct {
appName string
}

func (c *combinedCache) Set(ctx context.Context, key string, value interface{}, _ time.Duration) error {
func (c *combinedCache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
c.logger.Sugar().Debugf("set key on combined cache, key: [%s]", key)

if err := c.remoteCache.Set(ctx, key, value, c.ttl); err != nil {
Expand All @@ -35,8 +35,7 @@ func (c *combinedCache) Set(ctx context.Context, key string, value interface{},
}
}

// ttl is controlled by cache instantiation, so it does not matter here
if err := c.localCache.Set(ctx, key, value, c.ttl); err != nil {
if err := c.localCache.Set(ctx, key, value, ttl); err != nil {
c.logger.Sugar().Errorf("error setting key on combined/local cache, key: [%s], err: %s", key, err)
return err
}
Expand Down
16 changes: 16 additions & 0 deletions pkg/zcache/zcache_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,29 @@ package zcache
import (
"context"
"github.com/stretchr/testify/mock"
"github.com/zondax/golem/pkg/metrics"
"time"
)

type MockZCache struct {
mock.Mock
}

func (m *MockZCache) GetStats() ZCacheStats {
args := m.Called()
return args.Get(0).(ZCacheStats)
}

func (m *MockZCache) IsNotFoundError(err error) bool {
args := m.Called(err)
return args.Bool(0)
}

func (m *MockZCache) SetupAndMonitorMetrics(appName string, metricsServer metrics.TaskMetrics, updateInterval time.Duration) []error {
args := m.Called(appName, metricsServer, updateInterval)
return args.Get(0).([]error)
}

func (m *MockZCache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
args := m.Called(ctx, key, value, ttl)
return args.Error(0)
Expand Down
7 changes: 7 additions & 0 deletions pkg/zrouter/domain/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package domain

import "time"

type CacheConfig struct {
Paths map[string]time.Duration
}
70 changes: 70 additions & 0 deletions pkg/zrouter/zmiddlewares/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package zmiddlewares

import (
"fmt"
"github.com/zondax/golem/pkg/zcache"
"github.com/zondax/golem/pkg/zrouter/domain"
"go.uber.org/zap"
"net/http"
"runtime/debug"
"time"
)

const (
cacheKeyPrefix = "zrouter_cache"
)

func CacheMiddleware(cache zcache.ZCache, config domain.CacheConfig, logger *zap.SugaredLogger) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
fullURL := constructFullURL(r)

if ttl, found := config.Paths[path]; found {
key := constructCacheKey(fullURL)

if tryServeFromCache(w, r, cache, key) {
return
}

mrw := &metricsResponseWriter{ResponseWriter: w}
next.ServeHTTP(mrw, r) // Important: This line needs to be BEFORE setting the cache.
cacheResponseIfNeeded(mrw, r, cache, key, ttl, logger)
}
})
}
}

func constructFullURL(r *http.Request) string {
fullURL := r.URL.Path
if queryString := r.URL.RawQuery; queryString != "" {
fullURL += "?" + queryString
}
return fullURL
}

func constructCacheKey(fullURL string) string {
return fmt.Sprintf("%s:%s", cacheKeyPrefix, fullURL)
}

func tryServeFromCache(w http.ResponseWriter, r *http.Request, cache zcache.ZCache, key string) bool {
var cachedResponse []byte
err := cache.Get(r.Context(), key, &cachedResponse)
if err == nil && cachedResponse != nil {
w.Header().Set(domain.ContentTypeHeader, domain.ContentTypeApplicationJSON)
_, _ = w.Write(cachedResponse)
return true
}
return false
}

func cacheResponseIfNeeded(mrw *metricsResponseWriter, r *http.Request, cache zcache.ZCache, key string, ttl time.Duration, logger *zap.SugaredLogger) {
if mrw.status != http.StatusOK {
return
}

responseBody := mrw.Body()
if err := cache.Set(r.Context(), key, responseBody, ttl); err != nil {
logger.Errorf("Internal error when setting cache response: %v\n%s", err, debug.Stack())
}
}
13 changes: 13 additions & 0 deletions pkg/zrouter/zmiddlewares/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package zmiddlewares

import (
"bytes"
"net/http"
)

Expand All @@ -10,6 +11,7 @@ type metricsResponseWriter struct {
http.ResponseWriter
status int
written int64
body *bytes.Buffer
}

func (mrw *metricsResponseWriter) WriteHeader(statusCode int) {
Expand All @@ -18,7 +20,18 @@ func (mrw *metricsResponseWriter) WriteHeader(statusCode int) {
}

func (mrw *metricsResponseWriter) Write(p []byte) (int, error) {
if mrw.body == nil {
mrw.body = new(bytes.Buffer)
}
mrw.body.Write(p)
n, err := mrw.ResponseWriter.Write(p)
mrw.written += int64(n)
return n, err
}

func (mrw *metricsResponseWriter) Body() []byte {
if mrw.body != nil {
return mrw.body.Bytes()
}
return nil
}
66 changes: 66 additions & 0 deletions pkg/zrouter/zmiddlewares/zcache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package zmiddlewares

import (
"github.com/zondax/golem/pkg/zrouter/domain"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/zondax/golem/pkg/zcache"
"go.uber.org/zap"
)

func TestCacheMiddleware(t *testing.T) {
r := chi.NewRouter()
logger, _ := zap.NewDevelopment()
mockCache := new(zcache.MockZCache)

cacheConfig := domain.CacheConfig{Paths: map[string]time.Duration{
"/cached-path": 5 * time.Minute,
}}

r.Use(CacheMiddleware(mockCache, cacheConfig, logger.Sugar()))

// Simulate a response that should be cached
r.Get("/cached-path", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Test!"))
})

cachedResponseBody := []byte("Test!")

// Setup the mock for the first request (cache miss)
mockCache.On("Get", mock.Anything, "zrouter_cache:/cached-path", mock.AnythingOfType("*[]uint8")).Return(nil).Once()
mockCache.On("Set", mock.Anything, "zrouter_cache:/cached-path", cachedResponseBody, 5*time.Minute).Return(nil).Once()

// Setup the mock for the second request (cache hit)
mockCache.On("Get", mock.Anything, "zrouter_cache:/cached-path", mock.AnythingOfType("*[]uint8")).Return(nil).Run(func(args mock.Arguments) {
arg := args.Get(2).(*[]byte) // Get the argument where the cached response will be stored
*arg = cachedResponseBody // Simulate the cached response
})

// Perform the first request: the response should be generated and cached
req := httptest.NewRequest("GET", "/cached-path", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)

assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "Test!", rec.Body.String())

// Verify that the cache mock was invoked correctly
mockCache.AssertExpectations(t)

// Perform the second request: the response should be served from the cache
rec2 := httptest.NewRecorder()
r.ServeHTTP(rec2, req)

assert.Equal(t, http.StatusOK, rec2.Code)
assert.Equal(t, "Test!", rec2.Body.String())

// Verify that the cache mock was invoked correctly for the second request
mockCache.AssertExpectations(t)
}

0 comments on commit 0a3a76b

Please sign in to comment.