Skip to content

Commit

Permalink
✨ Add etag (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
tosone authored Feb 21, 2024
1 parent 3aa8630 commit c738843
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 0 deletions.
1 change: 1 addition & 0 deletions pkg/cmds/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func Serve(serverConfig ServerConfig) error {
config := ptr.To(configs.GetConfiguration())

e.Use(middleware.CORS())
e.Use(middlewares.Etag())
e.Use(echoprometheus.NewMiddleware(consts.AppName))
e.GET("/metrics", echoprometheus.NewHandler())
e.Use(middlewares.Healthz())
Expand Down
176 changes: 176 additions & 0 deletions pkg/middlewares/etag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright 2024 sigma
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package middlewares

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"hash/crc32"
"net/http"
"strconv"

"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)

// EtagConfig defines the config for Etag middleware.
type EtagConfig struct {
// Skipper defines a function to skip middleware.
Skipper middleware.Skipper
// Weak defines if the Etag is weak or strong.
Weak bool
// HashFn defines the hash function to use. Default is crc32q.
HashFn func(config EtagConfig) hash.Hash
}

var (
// DefaultEtagConfig is the default Etag middleware config.
DefaultEtagConfig = EtagConfig{
Skipper: middleware.DefaultSkipper,
Weak: true,
HashFn: func(config EtagConfig) hash.Hash {
if config.Weak {
const crcPol = 0xD5828281
crc32qTable := crc32.MakeTable(crcPol)
return crc32.New(crc32qTable)
}
return sha256.New()
},
}
normalizedETagName = http.CanonicalHeaderKey("Etag")
normalizedIfNoneMatchName = http.CanonicalHeaderKey("If-None-Match")
weakPrefix = "W/"
)

// Etag returns a Etag middleware.
func Etag() echo.MiddlewareFunc {
c := DefaultEtagConfig
return WithEtagConfig(c)
}

// WithEtagConfig returns a Etag middleware with config.
func WithEtagConfig(config EtagConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultEtagConfig.Skipper
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
skipper := config.Skipper
if skipper == nil {
skipper = DefaultEtagConfig.Skipper
}

if skipper(c) {
return next(c)
}

// get the hash function
hashFn := config.HashFn
if hashFn == nil {
hashFn = DefaultEtagConfig.HashFn
}

originalWriter := c.Response().Writer
res := c.Response()
req := c.Request()
// ResponseWriter
hw := bufferedWriter{rw: res.Writer, hash: hashFn(config), buf: bytes.NewBuffer(nil)}
res.Writer = &hw
err = next(c)
// restore the original writer
res.Writer = originalWriter
if err != nil {
return err
}

resHeader := res.Header()

if hw.hash == nil ||
resHeader.Get(normalizedETagName) != "" ||
strconv.Itoa(hw.status)[0] != '2' ||
hw.status == http.StatusNoContent ||
hw.buf.Len() == 0 {
writeRaw(originalWriter, hw)
return
}

etag := fmt.Sprintf("%v-%v", strconv.Itoa(hw.len),
hex.EncodeToString(hw.hash.Sum(nil)))

if config.Weak {
etag = weakPrefix + etag
}

resHeader.Set(normalizedETagName, etag)

ifNoneMatch := req.Header.Get(normalizedIfNoneMatchName) // get the If-None-Match header
headerFresh := ifNoneMatch == etag || ifNoneMatch == weakPrefix+etag

if headerFresh {
originalWriter.WriteHeader(http.StatusNotModified)
originalWriter.Write(nil) // nolint: errcheck
} else {
writeRaw(originalWriter, hw)
}
return
}
}
}

// bufferedWriter is a wrapper around http.ResponseWriter that
// buffers the response and calculates the hash of the response.
type bufferedWriter struct {
rw http.ResponseWriter
hash hash.Hash
buf *bytes.Buffer
len int
status int
}

// Header returns the header map that will be sent by
func (hw bufferedWriter) Header() http.Header {
return hw.rw.Header()
}

// WriteHeader sends an HTTP response header with the provided status code.
func (hw *bufferedWriter) WriteHeader(status int) {
hw.status = status
}

// Write writes the data to the buffer to be sent as part of an HTTP reply.
func (hw *bufferedWriter) Write(b []byte) (int, error) {
if hw.status == 0 {
hw.status = http.StatusOK
}
// write to the buffer
l, err := hw.buf.Write(b)
if err != nil {
return l, err
}
// write to the hash
l, err = hw.hash.Write(b)
hw.len += l
return l, err
}

// WriteTo writes the buffered data to the underlying io.Writer.
func writeRaw(res http.ResponseWriter, hw bufferedWriter) {
res.WriteHeader(hw.status)
res.Write(hw.buf.Bytes()) // nolint: errcheck
}
144 changes: 144 additions & 0 deletions pkg/middlewares/etag_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright 2024 sigma
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package middlewares

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/labstack/echo/v4"
)

var testWeakEtag = "W/11-8dcfee46"
var testStrEtag = "11-a591a6d40bf420404a011733cfb7b190d62c65bf0bcda32b57b277d9ad9f146e"
var e *echo.Echo

func init() {
e = echo.New()
e.GET("/etag", func(c echo.Context) error {
return c.String(200, "Hello World")
}, WithEtagConfig(EtagConfig{Weak: false}))

e.GET("/etag/weak", func(c echo.Context) error {
return c.String(200, "Hello World")
}, Etag())
}

func TestStrongEtag(t *testing.T) {
// Test strong Etag
req := httptest.NewRequest(http.MethodGet, "/etag", nil)
rec := httptest.NewRecorder()

e.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rec.Code)
}

if rec.Header().Get("Etag") != testStrEtag {
t.Errorf("Expected Etag %s, got %s", testStrEtag, rec.Header().Get("Etag"))
}

if rec.Body.String() != "Hello World" {
t.Errorf("Expected body %s, got %s", "Hello World", rec.Body.String())
}

// Test If-None-Match
req = httptest.NewRequest(http.MethodGet, "/etag", nil)
req.Header.Set("If-None-Match", testStrEtag)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
if rec.Code != http.StatusNotModified {
t.Errorf("Expected status code %d, got %d", http.StatusNotModified, rec.Code)
}

if rec.Header().Get("Etag") != testStrEtag {
t.Errorf("Expected Etag %s, got %s", testStrEtag, rec.Header().Get("Etag"))
}

if rec.Body.String() != "" {
t.Errorf("Expected body %s, got %s", "", rec.Body.String())
}

// Test If-None-Match invalid
req = httptest.NewRequest(http.MethodGet, "/etag", nil)
req.Header.Set("If-None-Match", "invalid")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rec.Code)
}

if rec.Header().Get("Etag") != testStrEtag {
t.Errorf("Expected Etag %s, got %s", testStrEtag, rec.Header().Get("Etag"))
}

if rec.Body.String() != "Hello World" {
t.Errorf("Expected body %s, got %s", "Hello World", rec.Body.String())
}
}

func TestWeakEtag(t *testing.T) {
// Test weak Etag
req := httptest.NewRequest(http.MethodGet, "/etag/weak", nil)
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rec.Code)
}

if rec.Header().Get("Etag") != testWeakEtag {
t.Errorf("Expected Etag %s, got %s", testWeakEtag, rec.Header().Get("Etag"))
}

if rec.Body.String() != "Hello World" {
t.Errorf("Expected body %s, got %s", "Hello World", rec.Body.String())
}

// Test If-None-Match weak
req = httptest.NewRequest(http.MethodGet, "/etag/weak", nil)
req.Header.Set("If-None-Match", testWeakEtag)
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
if rec.Code != http.StatusNotModified {
t.Errorf("Expected status code %d, got %d", http.StatusNotModified, rec.Code)
}

if rec.Header().Get("Etag") != testWeakEtag {
t.Errorf("Expected Etag %s, got %s", testWeakEtag, rec.Header().Get("Etag"))
}

if rec.Body.String() != "" {
t.Errorf("Expected body %s, got %s", "", rec.Body.String())
}

// Test If-None-Match weak invalid
req = httptest.NewRequest(http.MethodGet, "/etag/weak", nil)
req.Header.Set("If-None-Match", "invalid")
rec = httptest.NewRecorder()
e.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusOK, rec.Code)
}

if rec.Header().Get("Etag") != testWeakEtag {
t.Errorf("Expected Etag %s, got %s", testWeakEtag, rec.Header().Get("Etag"))
}

if rec.Body.String() != "Hello World" {
t.Errorf("Expected body %s, got %s", "Hello World", rec.Body.String())
}
}

0 comments on commit c738843

Please sign in to comment.