Skip to content

Commit

Permalink
Cached responses only with 2xx status code
Browse files Browse the repository at this point in the history
  • Loading branch information
evg4b committed Jul 7, 2023
1 parent 460b93d commit dae58ca
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 41 deletions.
2 changes: 1 addition & 1 deletion internal/config/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type CacheConfig struct {
}

func (config *CacheConfig) Clone() *CacheConfig {
var methods []string = nil
var methods []string
if config.Methods != nil {
methods = append(methods, config.Methods...)
}
Expand Down
6 changes: 5 additions & 1 deletion internal/handler/cache/cacheable_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,15 @@ func (w *CacheableResponseWriter) WriteHeader(statusCode int) {

func (w *CacheableResponseWriter) GetCachedResponse() *CachedResponse {
header := w.original.Header().Clone()
header.Del(headers.ContentLength)
clearHeader(header)

return &CachedResponse{
Code: w.code,
Body: w.buffer.Bytes(),
Header: header,
}
}

func clearHeader(header http.Header) {
header.Del(headers.ContentLength)
}
7 changes: 4 additions & 3 deletions internal/handler/cache/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ func (m *Middleware) Wrap(next contracts.Handler) contracts.Handler {

cacheableWriter := NewCacheableWriter(writer)
next.ServeHTTP(cacheableWriter, request)

response := cacheableWriter.GetCachedResponse()
m.storage.Set(cacheKey, response, time.Hour)
if helpers.Is2xxCode(cacheableWriter.StatusCode()) {
response := cacheableWriter.GetCachedResponse()
m.storage.Set(cacheKey, response, time.Hour)
}
})
}

Expand Down
104 changes: 85 additions & 19 deletions internal/handler/cache/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import (
"github.com/stretchr/testify/assert"
)

func TestNewMiddleware(t *testing.T) {
func TestCacheMiddleware(t *testing.T) {
const expectedBody = "this is test"

expectedHeader := http.Header{
headers.ContentType: {"text/html; charset=iso-8859-1"},
headers.ContentEncoding: {"deflate, gzip"},
Expand All @@ -36,22 +37,58 @@ func TestNewMiddleware(t *testing.T) {

t.Run("should not call cached response just one time for", func(t *testing.T) {
tests := []struct {
name string
method string
path string
name string
method string
path string
statusCode int
}{
{name: "request in glob", method: http.MethodGet, path: "/api"},
{name: "request in glob with params", method: http.MethodGet, path: "/api?some=params"},
{name: "request in glob with other params", method: http.MethodGet, path: "/api?other=params"},
{name: "second level request in glob", method: http.MethodGet, path: "/api/comments"},
{name: "second level request in glob with params", method: http.MethodGet, path: "/api/comments?q=test"},
{name: "third level request in glob", method: http.MethodGet, path: "/api/comments/1"},
{name: "third level request in glob with params", method: http.MethodGet, path: "/api/comments/1?q=demo"},
{
name: "request in glob",
method: http.MethodGet,
path: "/api",
statusCode: http.StatusOK,
},
{
name: "request in glob with params",
method: http.MethodGet,
path: "/api?some=params",
statusCode: http.StatusOK,
},
{
name: "request in glob with other params",
method: http.MethodGet,
path: "/api?other=params",
statusCode: http.StatusOK,
},
{
name: "second level request in glob",
method: http.MethodGet,
path: "/api/comments",
statusCode: http.StatusOK,
},
{
name: "second level request in glob with params",
method: http.MethodGet,
path: "/api/comments?q=test",
statusCode: http.StatusOK,
},
{
name: "third level request in glob",
method: http.MethodGet,
path: "/api/comments/1",
statusCode: http.StatusOK,
},
{
name: "third level request in glob with params",
method: http.MethodGet,
path: "/api/comments/1?q=demo",
statusCode: http.StatusOK,
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
handler := testutils.NewCounter(func(writer contracts.ResponseWriter, request *contracts.Request) {
writer.WriteHeader(http.StatusOK)
writer.WriteHeader(testCase.statusCode)
testutils.CopyHeaders(expectedHeader, writer.Header())
sfmt.Fprintf(writer, expectedBody)
})
Expand All @@ -75,17 +112,46 @@ func TestNewMiddleware(t *testing.T) {

t.Run("should not cache response", func(t *testing.T) {
tests := []struct {
name string
method string
path string
name string
method string
path string
statusCode int
}{
{name: "with path out of glob", method: http.MethodGet, path: "/test"},
{name: "with POST method", method: http.MethodPost, path: "/api"},
{
name: "witch path out of glob",
method: http.MethodGet,
path: "/test",
statusCode: http.StatusOK,
},
{
name: "from POST method request",
method: http.MethodPost,
path: "/api",
statusCode: http.StatusOK,
},
{
name: "witch response with status code 500",
method: http.MethodGet,
path: "/api/constants",
statusCode: http.StatusInternalServerError,
},
{
name: "witch response with status code 400",
method: http.MethodGet,
path: "/api/constants",
statusCode: http.StatusBadRequest,
},
{
name: "witch response with status code 304",
method: http.MethodGet,
path: "/api/constants",
statusCode: http.StatusNotModified,
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
handler := testutils.NewCounter(func(writer contracts.ResponseWriter, request *contracts.Request) {
writer.WriteHeader(http.StatusOK)
writer.WriteHeader(testCase.statusCode)
testutils.CopyHeaders(expectedHeader, writer.Header())
sfmt.Fprintf(writer, expectedBody)
})
Expand All @@ -96,7 +162,7 @@ func TestNewMiddleware(t *testing.T) {
recorder := httptest.NewRecorder()
wrappedHandler.ServeHTTP(
contracts.WrapResponseWriter(recorder),
httptest.NewRequest(http.MethodGet, "/test", nil),
httptest.NewRequest(testCase.method, testCase.path, nil),
)
assert.Equal(t, expectedHeader, recorder.Header())
assert.Equal(t, expectedBody, testutils.ReadBody(t, recorder))
Expand Down
20 changes: 20 additions & 0 deletions internal/helpers/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,23 @@ func NormaliseRequest(request *http.Request) {
request.URL.Scheme = "http"
}
}

func Is1xxCode(code int) bool {
return 100 <= code && code < 200
}

func Is2xxCode(code int) bool {
return 200 <= code && code < 300
}

func Is3xxCode(code int) bool {
return 300 <= code && code < 400
}

func Is4xxCode(code int) bool {
return 400 <= code && code < 500
}

func Is5xxCode(code int) bool {
return 500 <= code && code < 600
}
116 changes: 116 additions & 0 deletions internal/helpers/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package helpers_test

import (
"crypto/tls"
"fmt"
"net/http"
"testing"

Expand Down Expand Up @@ -53,3 +54,118 @@ func TestNormaliseRequest(t *testing.T) {
assert.Equal(t, request.URL.Host, testconstants.Localhost)
})
}

func TestIs1xxCode(t *testing.T) {
cases := []struct {
code int
expected bool
}{
{http.StatusContinue, true},
{http.StatusSwitchingProtocols, true},
{http.StatusOK, false},
{http.StatusMovedPermanently, false},
{http.StatusBadRequest, false},
{http.StatusInternalServerError, false},
}

for _, testCase := range cases {
t.Run(fmt.Sprintf("shoul return %t for code %d", testCase.expected, testCase.code), func(t *testing.T) {
actual := helpers.Is1xxCode(testCase.code)
assert.Equal(t, testCase.expected, actual)
})
}
}

func TestIs2xxCode(t *testing.T) {
cases := []struct {
code int
expected bool
}{
{http.StatusOK, true},
{http.StatusCreated, true},
{http.StatusAccepted, true},
{http.StatusSwitchingProtocols, false},
{http.StatusMultipleChoices, false},
{http.StatusBadRequest, false},
{http.StatusInternalServerError, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("shoul return %t for code %d", c.expected, c.code), func(t *testing.T) {
actual := helpers.Is2xxCode(c.code)

assert.Equal(t, c.expected, actual)
})
}
}

func TestIs3xxCode(t *testing.T) {
cases := []struct {
code int
expected bool
}{
{http.StatusMultipleChoices, true},
{http.StatusMovedPermanently, true},
{http.StatusFound, true},
{http.StatusOK, false},
{http.StatusSwitchingProtocols, false},
{http.StatusBadRequest, false},
{http.StatusInternalServerError, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("shoul return %t for code %d", c.expected, c.code), func(t *testing.T) {
actual := helpers.Is3xxCode(c.code)

assert.Equal(t, c.expected, actual)
})
}
}

func TestIs4xxCode(t *testing.T) {
cases := []struct {
code int
expected bool
}{
{http.StatusBadRequest, true},
{http.StatusUnauthorized, true},
{http.StatusForbidden, true},
{http.StatusOK, false},
{http.StatusSwitchingProtocols, false},
{http.StatusMultipleChoices, false},
{http.StatusInternalServerError, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("shoul return %t for code %d", c.expected, c.code), func(t *testing.T) {
actual := helpers.Is4xxCode(c.code)

assert.Equal(t, c.expected, actual)
})
}
}

func TestIs5xxCode(t *testing.T) {
cases := []struct {
code int
expected bool
}{
{http.StatusBadRequest, false},
{http.StatusUnauthorized, false},
{http.StatusForbidden, false},
{http.StatusOK, false},
{http.StatusSwitchingProtocols, false},
{http.StatusMultipleChoices, false},
{http.StatusInternalServerError, true},
{http.StatusNetworkAuthenticationRequired, true},
{http.StatusHTTPVersionNotSupported, true},
}

for _, c := range cases {
t.Run(fmt.Sprintf("shoul return %t for code %d", c.expected, c.code), func(t *testing.T) {
actual := helpers.Is5xxCode(c.code)

assert.Equal(t, c.expected, actual)
})
}
}
27 changes: 14 additions & 13 deletions internal/log/printresponse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package log

import (
"github.com/evg4b/uncors/internal/contracts"
"github.com/evg4b/uncors/internal/helpers"
"github.com/evg4b/uncors/internal/sfmt"
"github.com/pterm/pterm"
)
Expand All @@ -14,11 +15,7 @@ func printResponse(request *contracts.Request, statusCode int) string {
}

func getPrefixPrinter(statusCode int, text string) pterm.PrefixPrinter {
if statusCode < 100 || statusCode > 599 {
panic(sfmt.Sprintf("status code %d is not supported", statusCode))
}

if 100 <= statusCode && statusCode <= 199 {
if helpers.Is1xxCode(statusCode) {
return pterm.PrefixPrinter{
MessageStyle: &pterm.ThemeDefault.InfoMessageStyle,
Prefix: pterm.Prefix{
Expand All @@ -28,7 +25,7 @@ func getPrefixPrinter(statusCode int, text string) pterm.PrefixPrinter {
}
}

if 200 <= statusCode && statusCode <= 299 {
if helpers.Is2xxCode(statusCode) {
return pterm.PrefixPrinter{
MessageStyle: &pterm.ThemeDefault.SuccessMessageStyle,
Prefix: pterm.Prefix{
Expand All @@ -38,7 +35,7 @@ func getPrefixPrinter(statusCode int, text string) pterm.PrefixPrinter {
}
}

if 300 <= statusCode && statusCode <= 399 {
if helpers.Is3xxCode(statusCode) {
return pterm.PrefixPrinter{
MessageStyle: &pterm.ThemeDefault.WarningMessageStyle,
Prefix: pterm.Prefix{
Expand All @@ -48,11 +45,15 @@ func getPrefixPrinter(statusCode int, text string) pterm.PrefixPrinter {
}
}

return pterm.PrefixPrinter{
MessageStyle: &pterm.ThemeDefault.ErrorMessageStyle,
Prefix: pterm.Prefix{
Style: &pterm.ThemeDefault.ErrorPrefixStyle,
Text: text,
},
if helpers.Is4xxCode(statusCode) || helpers.Is5xxCode(statusCode) {
return pterm.PrefixPrinter{
MessageStyle: &pterm.ThemeDefault.ErrorMessageStyle,
Prefix: pterm.Prefix{
Style: &pterm.ThemeDefault.ErrorPrefixStyle,
Text: text,
},
}
}

panic(sfmt.Sprintf("status code %d is not supported", statusCode))
}
Loading

0 comments on commit dae58ca

Please sign in to comment.