diff --git a/.github/linters/.jscpd.json b/.github/linters/.jscpd.json index 863b17a7..612ac5ba 100644 --- a/.github/linters/.jscpd.json +++ b/.github/linters/.jscpd.json @@ -3,6 +3,7 @@ "reporters": ["html", "markdown"], "ignore": [ "**/*_test.go", + "*.md", "**/*.md" ] } diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8383620b..b6148950 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -23,7 +23,7 @@ jobs: options: >- --entrypoint redis-server nginx: - image: fabiocicerchia/go-proxy-cache-test:nginx + image: nginx:1.21.1-alpine ports: - "40080:40080" # http - "40081:40081" # ws @@ -31,14 +31,24 @@ jobs: - "40443:40443" # https options: >- --link node + -v ${{ github.workspace }}/test/full-setup/nginx/vhost.conf:/etc/nginx/conf.d/vhost.conf + -v ${{ github.workspace }}/test/full-setup/certs:/certs node: - image: fabiocicerchia/go-proxy-cache-test:node + image: node:16.8.0-alpine3.14 ports: - "9001:9001" # ws - "9002:9002" # wss + options: >- + -w /home/node/app + -v ${{ github.workspace }}/test/full-setup:/home/node/app steps: - uses: actions/checkout@v2 + - name: Start npm + uses: docker://docker + with: + args: docker exec "${{ job.services.node.id }}" "npm start" + # Ref: https://github.community/t/services-and-volumes/16313 - name: Restart nginx uses: docker://docker diff --git a/.vscode/launch.json b/.vscode/launch.json index 5c037670..86006c4d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,7 +11,7 @@ "id": "config", "type": "promptString", "description": "Configuration File", - "default": "examples/config.no-docker.yml" + "default": "examples/config.yml" } ], "configurations": [ diff --git a/Makefile b/Makefile index d5f60b2b..fb0c6f62 100644 --- a/Makefile +++ b/Makefile @@ -91,13 +91,13 @@ tlsfuzzer: ## tlsfuzzer test: test-unit test-functional test-endtoend test-ws test-http2 ## test test-unit: ## test unit - GPC_SYNC_STORING=1 go test -v -race -count=1 --tags=unit ./... + GPC_SYNC_STORING=1 TESTING=1 go test -v -race -count=1 --tags=unit ./... test-functional: ## test functional - GPC_SYNC_STORING=1 go test -v -race -count=1 --tags=functional ./... + GPC_SYNC_STORING=1 TESTING=1 go test -v -race -count=1 --tags=functional ./... test-endtoend: ## test endtoend - GPC_SYNC_STORING=1 go test -v -race -count=1 --tags=endtoend ./... + go test -v -race -count=1 --tags=endtoend ./... test-ws: ## test websocket cd test/full-setup && npm install @@ -111,7 +111,7 @@ test-http2: ## test HTTP2 fi cover: ## coverage - GPC_SYNC_STORING=1 go test -race -count=1 --tags=unit,functional -coverprofile c.out ./... + GPC_SYNC_STORING=1 TESTING=1 go test -race -count=1 --tags=unit,functional -coverprofile c.out ./... go tool cover -func=c.out go tool cover -html=c.out diff --git a/README.md b/README.md index 48dfa4f9..47eafab3 100644 --- a/README.md +++ b/README.md @@ -223,7 +223,7 @@ For examples check the relative documentation in [docs/EXAMPLES.md](https://gith ## OpenSSL This product includes software developed by the OpenSSL Project for use in the -OpenSSL Toolkit. (http://www.openssl.org/) +OpenSSL Toolkit. ([http://www.openssl.org/](http://www.openssl.org/)) ## Go Proxy Cache diff --git a/cache/cache.go b/cache/cache.go index a29ebdbd..c0e8b649 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -20,11 +20,11 @@ import ( "time" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" "github.com/fabiocicerchia/go-proxy-cache/cache/engine" "github.com/fabiocicerchia/go-proxy-cache/utils" "github.com/fabiocicerchia/go-proxy-cache/utils/slice" - log "github.com/sirupsen/logrus" ) var errMissingRedisConnection = errors.New("missing redis connection") @@ -50,6 +50,7 @@ const FreshSuffix = "/fresh" // Object - Contains cache settings and current cached/cacheable object. type Object struct { + ReqID string AllowedStatuses []int AllowedMethods []string CurrentURIObject URIObj @@ -135,7 +136,9 @@ func (c Object) handleMetadata(domainID string, targetURL url.URL, expiration ti // StoreFullPage - Stores the whole page response in cache. func (c Object) StoreFullPage(expiration time.Duration) (bool, error) { if !c.IsStatusAllowed() || !c.IsMethodAllowed() || expiration < 1 { - log.Debugf( + log.WithFields(log.Fields{ + "ReqID": c.ReqID, + }).Debugf( "Not allowed to be stored. Status: %v - Method: %v - Expiration: %v", c.IsStatusAllowed(), c.IsMethodAllowed(), @@ -192,7 +195,9 @@ func (c *Object) RetrieveFullPage() error { } key := StorageKey(c.CurrentURIObject, meta) - log.Debugf("StorageKey: %s", key) + log.WithFields(log.Fields{ + "ReqID": c.ReqID, + }).Debugf("StorageKey: %s", key) var stale bool = false encoded, err := conn.Get(key + FreshSuffix) diff --git a/cache/engine/client/client.go b/cache/engine/client/client.go index 7430d97b..f281a4f0 100644 --- a/cache/engine/client/client.go +++ b/cache/engine/client/client.go @@ -31,12 +31,13 @@ var ctx = context.Background() type RedisClient struct { *goredislib.Client *redsync.Redsync - Name string - Mutex map[string]*redsync.Mutex + Name string + Mutex map[string]*redsync.Mutex + logger *log.Logger } // Connect - Connects to DB. -func Connect(connName string, config config.Cache) *RedisClient { +func Connect(connName string, config config.Cache, logger *log.Logger) *RedisClient { client := goredislib.NewClient(&goredislib.Options{ Addr: config.Host + ":" + config.Port, Password: config.Password, @@ -50,6 +51,7 @@ func Connect(connName string, config config.Cache) *RedisClient { Client: client, Redsync: rs, Mutex: make(map[string]*redsync.Mutex), + logger: logger, } return rdb @@ -61,7 +63,7 @@ func (rdb *RedisClient) Close() error { } func (rdb *RedisClient) getMutex(key string) *redsync.Mutex { - mutexname := fmt.Sprintf("mutex-%s", key) + mutexname := fmt.Sprintf("mutex-%s-%s", rdb.Name, key) if _, ok := rdb.Mutex[mutexname]; !ok { rdb.Mutex[mutexname] = rdb.Redsync.NewMutex(mutexname) } @@ -71,7 +73,7 @@ func (rdb *RedisClient) getMutex(key string) *redsync.Mutex { func (rdb *RedisClient) lock(key string) error { if err := rdb.getMutex(key).Lock(); err != nil { - log.Errorf("Lock Error on %s: %s", key, err) + rdb.logger.Errorf("Lock Error on %s: %s", key, err) return err } @@ -80,7 +82,7 @@ func (rdb *RedisClient) lock(key string) error { func (rdb *RedisClient) unlock(key string) error { if ok, err := rdb.getMutex(key).Unlock(); !ok || err != nil { - log.Errorf("Unlock Error on %s: %s", key, err) + rdb.logger.Errorf("Unlock Error on %s: %s", key, err) return err } diff --git a/cache/engine/client/client_test.go b/cache/engine/client/client_test.go index 59f7578e..6f61b852 100644 --- a/cache/engine/client/client_test.go +++ b/cache/engine/client/client_test.go @@ -15,15 +15,19 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" log "github.com/sirupsen/logrus" "github.com/fabiocicerchia/go-proxy-cache/cache/engine/client" "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/utils" circuit_breaker "github.com/fabiocicerchia/go-proxy-cache/utils/circuit-breaker" - "github.com/stretchr/testify/assert" ) +// this is to verify any possible data race condition +const redisConnName = "testing" +const clashingKey = "test" + func initLogs() { log.SetReportCaller(true) log.SetLevel(log.DebugLevel) @@ -37,7 +41,7 @@ func initLogs() { func TestCircuitBreakerWithPingTimeout(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -51,33 +55,33 @@ func TestCircuitBreakerWithPingTimeout(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) - assert.Equal(t, "closed", circuit_breaker.CB("testing").State().String()) + assert.Equal(t, "closed", circuit_breaker.CB(redisConnName).State().String()) val := rdb.Ping() assert.True(t, val) - assert.Equal(t, "closed", circuit_breaker.CB("testing").State().String()) + assert.Equal(t, "closed", circuit_breaker.CB(redisConnName).State().String()) _ = rdb.Close() val = rdb.Ping() assert.False(t, val) - assert.Equal(t, "half-open", circuit_breaker.CB("testing").State().String()) + assert.Equal(t, "half-open", circuit_breaker.CB(redisConnName).State().String()) - rdb = client.Connect("testing", config.Config.Cache) + rdb = client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) val = rdb.Ping() assert.True(t, val) - assert.Equal(t, "closed", circuit_breaker.CB("testing").State().String()) + assert.Equal(t, "closed", circuit_breaker.CB(redisConnName).State().String()) } func TestClose(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -91,9 +95,9 @@ func TestClose(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) assert.True(t, rdb.Ping()) @@ -105,7 +109,7 @@ func TestClose(t *testing.T) { func TestSetGet(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -119,15 +123,15 @@ func TestSetGet(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) - done, err := rdb.Set("test", "sample", 0) + done, err := rdb.Set(clashingKey, "sample", 0) assert.True(t, done) assert.Nil(t, err) - value, err := rdb.Get("test") + value, err := rdb.Get(clashingKey) assert.Equal(t, "sample", value) assert.Nil(t, err) } @@ -135,7 +139,7 @@ func TestSetGet(t *testing.T) { func TestSetGetWithExpiration(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -149,17 +153,17 @@ func TestSetGetWithExpiration(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) - done, err := rdb.Set("test", "sample", 1*time.Millisecond) + done, err := rdb.Set(clashingKey, "sample", 1*time.Millisecond) assert.True(t, done) assert.Nil(t, err) - time.Sleep(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) // let it expire in Redis - value, err := rdb.Get("test") + value, err := rdb.Get(clashingKey) assert.Equal(t, "", value) assert.Nil(t, err) } @@ -167,7 +171,7 @@ func TestSetGetWithExpiration(t *testing.T) { func TestDel(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -181,22 +185,22 @@ func TestDel(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) - done, err := rdb.Set("test", "sample", 0) + done, err := rdb.Set(clashingKey, "sample", 0) assert.True(t, done) assert.Nil(t, err) - value, err := rdb.Get("test") + value, err := rdb.Get(clashingKey) assert.Equal(t, "sample", value) assert.Nil(t, err) - err = rdb.Del("test") + err = rdb.Del(clashingKey) assert.Nil(t, err) - value, err = rdb.Get("test") + value, err = rdb.Get(clashingKey) assert.Equal(t, "", value) assert.Nil(t, err) } @@ -204,7 +208,7 @@ func TestDel(t *testing.T) { func TestExpire(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -218,20 +222,21 @@ func TestExpire(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) - done, err := rdb.Set("test", "sample", 0) + done, err := rdb.Set(clashingKey, "sample", 0) assert.True(t, done) assert.Nil(t, err) - err = rdb.Expire("test", 1*time.Second) + // redis: commands.go:36: specified duration is 100ms, but minimal supported value is 1s - truncating to 1s + err = rdb.Expire(clashingKey, 100*time.Millisecond) assert.Nil(t, err) - time.Sleep(1500 * time.Millisecond) + time.Sleep(2 * time.Second) - value, err := rdb.Get("test") + value, err := rdb.Get(clashingKey) assert.Equal(t, "", value) assert.Nil(t, err) } @@ -239,7 +244,7 @@ func TestExpire(t *testing.T) { func TestPushList(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -253,22 +258,78 @@ func TestPushList(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) - err := rdb.Push("test", []string{"a", "b", "c"}) + err := rdb.Push(clashingKey, []string{"a", "b", "c"}) assert.Nil(t, err) - value, err := rdb.List("test") + value, err := rdb.List(clashingKey) assert.Equal(t, []string{"a", "b", "c"}, value) assert.Nil(t, err) } +func TestDelWildcardNoMatch(t *testing.T) { + initLogs() + + cfg := config.Configuration{ + Cache: config.Cache{ + Host: utils.GetEnv("REDIS_HOST", "localhost"), + Port: "6379", + DB: 0, + }, + CircuitBreaker: circuit_breaker.CircuitBreaker{ + Threshold: 2, // after 2nd request, if meet FailureRate goes open. + FailureRate: 0.5, // 1 out of 2 fails, or more + Interval: 0, // doesn't clears counts + Timeout: time.Duration(1), // clears state immediately + }, + } + + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) + + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) + + done, err := rdb.Set("test_1", "sample", 0) + assert.True(t, done) + assert.Nil(t, err) + done, err = rdb.Set("test_2", "sample", 0) + assert.True(t, done) + assert.Nil(t, err) + done, err = rdb.Set("test_3", "sample", 0) + assert.True(t, done) + assert.Nil(t, err) + + value, err := rdb.Get("test_1") + assert.Equal(t, "sample", value) + assert.Nil(t, err) + value, err = rdb.Get("test_2") + assert.Equal(t, "sample", value) + assert.Nil(t, err) + value, err = rdb.Get("test_3") + assert.Equal(t, "sample", value) + assert.Nil(t, err) + + len, err := rdb.DelWildcard("missing_*") + assert.Equal(t, 0, len) + assert.Nil(t, err) + + value, err = rdb.Get("test_1") + assert.Equal(t, "sample", value) + assert.Nil(t, err) + value, err = rdb.Get("test_2") + assert.Equal(t, "sample", value) + assert.Nil(t, err) + value, err = rdb.Get("test_3") + assert.Equal(t, "sample", value) + assert.Nil(t, err) +} + func TestDelWildcard(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -282,9 +343,9 @@ func TestDelWildcard(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) done, err := rdb.Set("test_1", "sample", 0) assert.True(t, done) @@ -324,7 +385,7 @@ func TestDelWildcard(t *testing.T) { func TestPurgeAll(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -338,9 +399,9 @@ func TestPurgeAll(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) done, err := rdb.Set("test_1", "sample", 0) assert.True(t, done) @@ -380,7 +441,7 @@ func TestPurgeAll(t *testing.T) { func TestEncodeDecode(t *testing.T) { initLogs() - config.Config = config.Configuration{ + cfg := config.Configuration{ Cache: config.Cache{ Host: utils.GetEnv("REDIS_HOST", "localhost"), Port: "6379", @@ -394,9 +455,9 @@ func TestEncodeDecode(t *testing.T) { }, } - circuit_breaker.InitCircuitBreaker("testing", config.Config.CircuitBreaker) + circuit_breaker.InitCircuitBreaker(redisConnName, cfg.CircuitBreaker) - rdb := client.Connect("testing", config.Config.Cache) + rdb := client.Connect(redisConnName, cfg.Cache, log.StandardLogger()) str := []byte("test string") var decoded []byte diff --git a/cache/engine/redis.go b/cache/engine/redis.go index f5aa4d17..5b435830 100644 --- a/cache/engine/redis.go +++ b/cache/engine/redis.go @@ -30,11 +30,11 @@ func GetConn(connName string) *client.RedisClient { } // InitConn - Initialises the Redis connection. -func InitConn(connName string, config config.Cache) { +func InitConn(connName string, config config.Cache, logger *log.Logger) { if rdb == nil { rdb = make(map[string]*client.RedisClient) } - log.Debugf("New redis connection for %s", connName) - rdb[connName] = client.Connect(connName, config) + logger.Debugf("New redis connection for %s", connName) + rdb[connName] = client.Connect(connName, config, logger) } diff --git a/config/config.go b/config/config.go index 4506d471..ef865186 100644 --- a/config/config.go +++ b/config/config.go @@ -24,8 +24,8 @@ import ( "gopkg.in/yaml.v2" "github.com/fabiocicerchia/go-proxy-cache/utils" + "github.com/fabiocicerchia/go-proxy-cache/utils/scheme" "github.com/fabiocicerchia/go-proxy-cache/utils/slice" - utilsString "github.com/fabiocicerchia/go-proxy-cache/utils/string" ) // PasswordOmittedValue - Replacement value when showing passwords in configuration. @@ -59,7 +59,7 @@ func getFromYaml(file string) (Configuration, error) { return YamlConfig, err } - YamlConfig.Server.Upstream.Scheme = utilsString.NormalizeScheme(YamlConfig.Server.Upstream.Scheme) + YamlConfig.Server.Upstream.Scheme = scheme.NormalizeScheme(YamlConfig.Server.Upstream.Scheme) return YamlConfig, err } @@ -244,47 +244,49 @@ func getSliceFromMap(domains map[string]DomainSet) []DomainSet { } // DomainConf - Returns the configuration for the requested domain (Global Access). -func DomainConf(domain string, scheme string) *Configuration { +func DomainConf(domain string, scheme string) (Configuration, bool) { return Config.DomainConf(domain, scheme) } // DomainConf - Returns the configuration for the requested domain. -func (c Configuration) DomainConf(domain string, scheme string) *Configuration { +func (c Configuration) DomainConf(domain string, scheme string) (Configuration, bool) { + var found bool + // Memoization if c.domainsCache == nil { - c.domainsCache = make(map[string]*Configuration) + c.domainsCache = make(map[string]Configuration) } keyCache := fmt.Sprintf("%s%s%s", domain, utils.StringSeparatorOne, scheme) if val, ok := c.domainsCache[keyCache]; ok { log.Debugf("Cached configuration for %s", keyCache) - return val + return val, true } - c.domainsCache[keyCache] = c.domainConfLookup(utils.StripPort(domain), scheme) + c.domainsCache[keyCache], found = c.domainConfLookup(utils.StripPort(domain), scheme) - return c.domainsCache[keyCache] + return c.domainsCache[keyCache], found } -func (c Configuration) domainConfLookup(domain string, scheme string) *Configuration { +func (c Configuration) domainConfLookup(domain string, scheme string) (Configuration, bool) { // First round: host & scheme for _, v := range c.Domains { if v.Server.Upstream.Host == domain && v.Server.Upstream.Scheme == scheme { - return &v + return v, true } } // Second round: host for _, v := range c.Domains { if v.Server.Upstream.Host == domain { - return &v + return v, true } } // Third round: global if c.Server.Upstream.Host == domain { - return &c + return c, true } - return nil + return Configuration{}, false } diff --git a/config/model.go b/config/model.go index ff979bf9..537ffaee 100644 --- a/config/model.go +++ b/config/model.go @@ -55,7 +55,7 @@ type Configuration struct { CircuitBreaker circuitbreaker.CircuitBreaker `yaml:"circuit_breaker"` Domains Domains `yaml:"domains"` Log Log `yaml:"log"` - domainsCache map[string]*Configuration + domainsCache map[string]Configuration } // Domains - Overrides per domain. diff --git a/docs/DEV.md b/docs/DEV.md index 67a5243b..9849def0 100644 --- a/docs/DEV.md +++ b/docs/DEV.md @@ -20,7 +20,7 @@ $ docker-compose up ## Test -**NOTE:** If you use docker please use `config.yml` otherwise `config.no-docker.yml`. The port will be different from host and container, this will address the issue. +**NOTE:** In order to have a fully working environment you need to put in the host file `127.0.0.1 nginx`. ```console $ make test diff --git a/go.mod b/go.mod index 08efb70a..5b4e97dd 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,10 @@ module github.com/fabiocicerchia/go-proxy-cache go 1.14 require ( - github.com/NYTimes/gziphandler v1.1.1 - github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 - github.com/go-http-utils/fresh v0.0.0-20161124030543-7231e26a4b27 // indirect - github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a // indirect + github.com/NYTimes/gziphandler v1.1.1 // indirect + github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1 // indirect + github.com/go-http-utils/fresh v0.0.0-20161124030543-7231e26a4b27 + github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a github.com/go-redis/redis/v8 v8.3.3 github.com/go-redsync/redsync/v4 v4.0.4 github.com/gorilla/websocket v1.4.2 @@ -14,13 +14,15 @@ require ( github.com/kelseyhightower/envconfig v1.4.0 github.com/kr/pretty v0.1.0 // indirect github.com/pkg/errors v0.9.1 + github.com/rs/xid v1.3.0 github.com/sdeoras/dispatcher v1.0.2 github.com/sirupsen/logrus v1.3.0 + github.com/soheilhy/cmux v0.1.5 // indirect github.com/sony/gobreaker v0.4.1 github.com/stretchr/testify v1.6.1 github.com/ugorji/go/codec v1.2.0 github.com/yhat/wsutil v0.0.0-20170731153501-1d66fa95c997 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 - golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 + golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb gopkg.in/yaml.v2 v2.3.0 ) diff --git a/go.sum b/go.sum index 9275ce59..d732efec 100644 --- a/go.sum +++ b/go.sum @@ -78,10 +78,14 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.3.0 h1:6NjYksEUlhurdVehpc7S7dk6DAmcKv8V9gG0FsVN2U4= +github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/sdeoras/dispatcher v1.0.2 h1:2PIohOpkza/bzcTjEMFrKC/PmfPkNgz+gCWdUh8LxUA= github.com/sdeoras/dispatcher v1.0.2/go.mod h1:WXv44FUh84I6S5lw9CgG14ClHMZQuPfLTtkLK1bhImE= github.com/sirupsen/logrus v1.3.0 h1:hI/7Q+DtNZ2kINb6qt/lS+IyXnHQe9e90POfeewL/ME= github.com/sirupsen/logrus v1.3.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/sony/gobreaker v0.4.1 h1:oMnRNZXX5j85zso6xCPRNPtmAycat+WcoKbklScLDgQ= github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -111,6 +115,8 @@ golang.org/x/net v0.0.0-20190923162816-aa69164e4478/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M= golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb h1:eBmm0M9fYhWpKZLjQUUKka/LtIxf46G4fxeEz5KJr9U= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/server/balancer/roundrobin/roundrobin_test.go b/server/balancer/roundrobin/roundrobin_test.go index d62257d8..1eb2d17f 100644 --- a/server/balancer/roundrobin/roundrobin_test.go +++ b/server/balancer/roundrobin/roundrobin_test.go @@ -16,9 +16,9 @@ import ( "testing" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/fabiocicerchia/go-proxy-cache/server/balancer/roundrobin" - "github.com/stretchr/testify/assert" ) func initLogs() { diff --git a/server/handler/connect_test.go b/server/handler/connect_test.go index 3ddb2700..ba0f3c3f 100644 --- a/server/handler/connect_test.go +++ b/server/handler/connect_test.go @@ -18,13 +18,15 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + log "github.com/sirupsen/logrus" + "github.com/fabiocicerchia/go-proxy-cache/cache/engine" "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/balancer" "github.com/fabiocicerchia/go-proxy-cache/server/handler" "github.com/fabiocicerchia/go-proxy-cache/utils" circuit_breaker "github.com/fabiocicerchia/go-proxy-cache/utils/circuit-breaker" - "github.com/stretchr/testify/assert" ) func TestEndToEndCallConnect(t *testing.T) { @@ -56,7 +58,7 @@ func TestEndToEndCallConnect(t *testing.T) { domainID := config.Config.Server.Upstream.GetDomainID() balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) req, err := http.NewRequest("CONNECT", "/", nil) req.URL.Scheme = config.Config.Server.Upstream.Scheme @@ -66,7 +68,7 @@ func TestEndToEndCallConnect(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(config.Config)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) diff --git a/server/handler/etag.go b/server/handler/etag.go new file mode 100644 index 00000000..89a9f9b1 --- /dev/null +++ b/server/handler/etag.go @@ -0,0 +1,49 @@ +package handler + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + "net/http/httputil" + + "github.com/go-http-utils/fresh" + "github.com/yhat/wsutil" +) + +const HttpVersion2 = 2 + +// HandleRequestWithETag - Add HTTP header ETag only on HTTP(S) requests. +func (rc RequestCall) GetResponseWithETag(proxy *httputil.ReverseProxy) (serveNotModified bool) { + // Start buffering the response. + proxy.ServeHTTP(rc.Response, &rc.Request) + + // ETag wrapper doesn't work well with WebSocket and HTTP/2. + if wsutil.IsWebSocketRequest(&rc.Request) || rc.Request.ProtoMajor == HttpVersion2 { + rc.GetLogger().Info("Current request doesn't support ETag.") + + // Serve existing response. + return false + } + + // Serve existing response. + if rc.Response.MustServeOriginalResponse(&rc.Request) { + rc.GetLogger().Info("Serving original response as cannot be handled with ETag.") + return false + } + + rc.Response.SetETag(false) + + // Send 304 Not Modified. + if fresh.IsFresh(rc.Request.Header, rc.Response.Header()) { + return true + } + + // Serve response with ETag header. + return false +} diff --git a/server/handler/etag_test.go b/server/handler/etag_test.go new file mode 100644 index 00000000..6d161f8d --- /dev/null +++ b/server/handler/etag_test.go @@ -0,0 +1,155 @@ +// +build all functional + +package handler_test + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + "crypto/tls" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "regexp" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/server/handler" + "github.com/fabiocicerchia/go-proxy-cache/server/response" +) + +func TestGetResponseWithETagWithHTTP2(t *testing.T) { + initLogs() + + reqMock := http.Request{ + Proto: "HTTPS", + ProtoMajor: 2, + Method: "GET", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Path: "/path/to/file"}, + Header: http.Header{ + "Host": []string{"localhost"}, + }, + } + + proxyUrl := &url.URL{Scheme: "https", Host: "www.w3.org"} + proxy := httputil.NewSingleHostReverseProxy(proxyUrl) + + reqID := "TestGetResponseWithETag" + rr := httptest.NewRecorder() + rcMock := handler.RequestCall{ + ReqID: reqID, + Response: response.NewLoggedResponseWriter(rr, reqID), + Request: reqMock, + } + + serveNotModified := rcMock.GetResponseWithETag(proxy) + + assert.False(t, serveNotModified) +} + +func TestGetResponseWithETagWithExistingETag(t *testing.T) { + initLogs() + + // Page with actual ETag + reqMock := http.Request{ + Proto: "HTTPS", + ProtoMajor: 2, + Method: "GET", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Scheme: "https", Host: "www.w3.org", Path: "/"}, + Header: http.Header{ + "Host": []string{"www.w3.org"}, + }, + } + + proxyUrl := &url.URL{Scheme: "https", Host: "www.w3.org"} + proxy := httputil.NewSingleHostReverseProxy(proxyUrl) + + reqID := "TestGetResponseWithETag" + rr := httptest.NewRecorder() + rr.Header().Add("ETag", "TestGetResponseWithETagWithExistingETag") + rcMock := handler.RequestCall{ + ReqID: reqID, + Response: response.NewLoggedResponseWriter(rr, reqID), + Request: reqMock, + } + + serveNotModified := rcMock.GetResponseWithETag(proxy) + + assert.False(t, serveNotModified) + assert.Equal(t, "TestGetResponseWithETagWithExistingETag", rr.Header().Get("ETag")) +} + +func TestGetResponseWithETagGeneratedInternally(t *testing.T) { + initLogs() + + // Page without actual ETag + reqMock := http.Request{ + Proto: "HTTPS", + Method: "GET", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Scheme: "https", Host: "www.google.com", Path: "/"}, + Header: http.Header{ + "Host": []string{"www.google.com"}, + }, + } + reqMock.TLS = &tls.ConnectionState{} // mock a fake https + + proxyUrl := &url.URL{Scheme: "https", Host: "www.google.com"} + proxy := httputil.NewSingleHostReverseProxy(proxyUrl) + + reqID := "TestGetResponseWithETagGeneratedInternally" + rr := httptest.NewRecorder() + rcMock := handler.RequestCall{ + ReqID: reqID, + Response: response.NewLoggedResponseWriter(rr, reqID), + Request: reqMock, + } + + serveNotModified := rcMock.GetResponseWithETag(proxy) + + assert.False(t, serveNotModified) + assert.Regexp(t, regexp.MustCompile(`^\"[0-9]+-[0-9a-f]{40}\"$`), rr.Header().Get("ETag")) +} + +func TestGetResponseWithETagGeneratedInternallyAndFresh(t *testing.T) { + initLogs() + + // Page without actual ETag + reqMock := http.Request{ + Proto: "HTTPS", + Method: "GET", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Scheme: "https", Host: "www.google.com", Path: "/"}, + Header: http.Header{ + "Host": []string{"www.google.com"}, + "If-None-Match": []string{"*"}, + }, + } + reqMock.TLS = &tls.ConnectionState{} // mock a fake https + + proxyUrl := &url.URL{Scheme: "https", Host: "www.google.com"} + proxy := httputil.NewSingleHostReverseProxy(proxyUrl) + + reqID := "TestGetResponseWithETagGeneratedInternally" + rr := httptest.NewRecorder() + rcMock := handler.RequestCall{ + ReqID: reqID, + Response: response.NewLoggedResponseWriter(rr, reqID), + Request: reqMock, + } + + serveNotModified := rcMock.GetResponseWithETag(proxy) + + assert.True(t, serveNotModified) + assert.Regexp(t, regexp.MustCompile(`^\"[0-9]+-[0-9a-f]{40}\"$`), rr.Header().Get("ETag")) +} diff --git a/server/middleware.go b/server/handler/gzip.go similarity index 52% rename from server/middleware.go rename to server/handler/gzip.go index a1064856..efc99882 100644 --- a/server/middleware.go +++ b/server/handler/gzip.go @@ -1,4 +1,4 @@ -package server +package handler // __ // .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. @@ -11,20 +11,18 @@ package server import ( "net/http" + "strings" - "github.com/go-http-utils/etag" - "github.com/yhat/wsutil" + "github.com/go-http-utils/headers" + + "github.com/fabiocicerchia/go-proxy-cache/server/response" ) -// ConditionalETag - Add HTTP header ETag only on HTTP(S) requests. -func ConditionalETag(h http.Handler) http.Handler { - return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { - // ETag wrapper doesn't work well with WebSocket and HTTP/2. - if !wsutil.IsWebSocketRequest(req) && req.ProtoMajor != 2 { - etagHandler := etag.Handler(h, false) - etagHandler.ServeHTTP(res, req) - } else { - h.ServeHTTP(res, req) - } - }) +// HandleRequestWithETag - Add HTTP header ETag only on HTTP(S) requests. +func WrapResponseForGZip(res *response.LoggedResponseWriter, req *http.Request) { + if !strings.Contains(req.Header.Get(headers.AcceptEncoding), "gzip") { + return + } + + res.Header().Set(headers.ContentEncoding, "gzip") } diff --git a/server/handler/gzip_test.go b/server/handler/gzip_test.go new file mode 100644 index 00000000..5454ffa6 --- /dev/null +++ b/server/handler/gzip_test.go @@ -0,0 +1,70 @@ +package handler_test + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/server/handler" + "github.com/fabiocicerchia/go-proxy-cache/server/response" +) + +// HandleRequestWithETag - Add HTTP header ETag only on HTTP(S) requests. +func TestWrapResponseForGZipWhenNoAcceptEncoding(t *testing.T) { + initLogs() + + reqMock := http.Request{ + Proto: "HTTPS", + Method: "GET", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Path: "/path/to/file"}, + Header: http.Header{ + "Host": []string{"localhost"}, + "Accept-Encoding": []string{""}, + }, + } + + reqID := "TestWrapResponseForGZipWhenNoAcceptEncoding" + rr := httptest.NewRecorder() + res := response.NewLoggedResponseWriter(rr, reqID) + + handler.WrapResponseForGZip(res, &reqMock) + + assert.Equal(t, "", rr.Header().Get("Content-Encoding")) +} + +// HandleRequestWithETag - Add HTTP header ETag only on HTTP(S) requests. +func TestWrapResponseForGZipWhenAcceptEncodingGZip(t *testing.T) { + initLogs() + + reqMock := http.Request{ + Proto: "HTTPS", + Method: "GET", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Path: "/path/to/file"}, + Header: http.Header{ + "Host": []string{"localhost"}, + "Accept-Encoding": []string{"gzip"}, + }, + } + + reqID := "TestWrapResponseForGZipWhenNoAcceptEncoding" + rr := httptest.NewRecorder() + res := response.NewLoggedResponseWriter(rr, reqID) + + handler.WrapResponseForGZip(res, &reqMock) + + assert.Equal(t, "gzip", rr.Header().Get("Content-Encoding")) +} diff --git a/server/handler/handler.go b/server/handler/handler.go index 65d59d32..bfa5ea18 100644 --- a/server/handler/handler.go +++ b/server/handler/handler.go @@ -10,61 +10,76 @@ package handler // Repo: https://github.com/fabiocicerchia/go-proxy-cache import ( + "fmt" "net/http" + "github.com/rs/xid" + "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/logger" "github.com/fabiocicerchia/go-proxy-cache/server/response" - log "github.com/sirupsen/logrus" ) +// HttpMethodPurge - PURGE method. +const HttpMethodPurge = "PURGE" + // HandleRequest - Handles the entrypoint and directs the traffic to the right handler. -func HandleRequest(cfg config.Configuration) func(res http.ResponseWriter, req *http.Request) { - return func(res http.ResponseWriter, req *http.Request) { - rc := initRequestParams(res, req, cfg) - if rc.DomainConfig == nil { - return - } +func HandleRequest(res http.ResponseWriter, req *http.Request) { + rc, err := initRequestParams(res, req) + if err != nil { + rc.GetLogger().Errorf(err.Error()) + return + } - if rc.GetScheme() == SchemeHTTP && rc.DomainConfig.Server.Upstream.HTTP2HTTPS { - rc.RedirectToHTTPS() - return + if rc.Request.Method == http.MethodConnect { + if enableLoggingRequest { + logger.LogRequest(rc.Request, *rc.Response, rc.ReqID, false, "-") } - if rc.Request.Method == "PURGE" { - rc.HandlePurge() - return - } + rc.Response.ForceWriteHeader(http.StatusMethodNotAllowed) + return + } - if rc.Request.Method == http.MethodConnect { - rc.Response.WriteHeader(http.StatusMethodNotAllowed) - return - } + if rc.GetScheme() == SchemeHTTP && rc.DomainConfig.Server.Upstream.HTTP2HTTPS { + rc.RedirectToHTTPS() + return + } - if rc.IsWebSocket() { - rc.HandleWSRequestAndProxy() - } else { - rc.HandleHTTPRequestAndProxy() - } + if rc.Request.Method == HttpMethodPurge { + rc.HandlePurge() + return + } + + if rc.IsWebSocket() { + rc.HandleWSRequestAndProxy() + } else { + rc.HandleHTTPRequestAndProxy() } } -func initRequestParams(res http.ResponseWriter, req *http.Request, cfg config.Configuration) RequestCall { +func initRequestParams(res http.ResponseWriter, req *http.Request) (RequestCall, error) { + var configFound bool + + reqID := xid.New().String() rc := RequestCall{ - Response: response.NewLoggedResponseWriter(res), - Request: req, + ReqID: reqID, + Response: response.NewLoggedResponseWriter(res, reqID), + Request: *req, } listeningPort := getListeningPort(req.Context()) - rc.DomainConfig = cfg.DomainConf(rc.GetHostname(), rc.GetScheme()) - if rc.DomainConfig == nil || !isLegitPort(rc.DomainConfig.Server.Port, listeningPort) { - rc.Response.WriteHeader(http.StatusNotImplemented) - logger.LogRequest(*rc.Request, *rc.Response, false, CacheStatusLabel[CacheStatusMiss]) - log.Errorf("Missing configuration in HandleRequest for %s (listening on :%s).", rc.Request.Host, listeningPort) + rc.DomainConfig, configFound = config.DomainConf(req.Host, rc.GetScheme()) + if !configFound || !rc.IsLegitRequest(listeningPort) { + rc.Response.SendNotImplemented() + + logger.LogRequest(rc.Request, *rc.Response, rc.ReqID, false, CacheStatusLabel[CacheStatusMiss]) + + return RequestCall{}, fmt.Errorf("Request for %s (listening on :%s) is not allowed (mostly likely it's a configuration mismatch).", rc.Request.Host, listeningPort) + } - return RequestCall{} + if rc.DomainConfig.Server.GZip { } - return rc + return rc, nil } diff --git a/server/handler/healthcheck.go b/server/handler/healthcheck.go index 07b38251..8da1dec9 100644 --- a/server/handler/healthcheck.go +++ b/server/handler/healthcheck.go @@ -20,10 +20,15 @@ import ( // HandleHealthcheck - Returns healthcheck status. func HandleHealthcheck(cfg config.Configuration) func(res http.ResponseWriter, req *http.Request) { return func(res http.ResponseWriter, req *http.Request) { - rc := initRequestParams(res, req, cfg) + rc, err := initRequestParams(res, req) + if err != nil { + rc.GetLogger().Errorf(err.Error()) + return + } + domainID := rc.DomainConfig.Server.Upstream.GetDomainID() - lwr := response.NewLoggedResponseWriter(res) + lwr := response.NewLoggedResponseWriter(res, rc.ReqID) lwr.WriteHeader(http.StatusOK) _ = lwr.WriteBody("HTTP OK\n") diff --git a/server/handler/healthcheck_test.go b/server/handler/healthcheck_test.go index 320dee38..a9471be0 100644 --- a/server/handler/healthcheck_test.go +++ b/server/handler/healthcheck_test.go @@ -52,14 +52,22 @@ func TestHealthcheckWithoutRedis(t *testing.T) { Interval: time.Duration(1), Timeout: time.Duration(1), // clears state immediately }, + Server: config.Server{ + Upstream: config.Upstream{ + Host: "testing.local", + Scheme: "https", + Endpoints: []string{utils.GetEnv("NGINX_HOST_80", "localhost:40080")}, + }, + }, } domainID := config.Config.Server.Upstream.GetDomainID() circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) engine.GetConn(domainID).Close() req, err := http.NewRequest("GET", "/healthcheck", nil) + req.Host = "testing.local" assert.Nil(t, err) rr := httptest.NewRecorder() @@ -72,7 +80,7 @@ func TestHealthcheckWithoutRedis(t *testing.T) { assert.Contains(t, rr.Body.String(), `REDIS KO`) assert.NotContains(t, rr.Body.String(), `REDIS OK`) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) } func TestHealthcheckWithRedis(t *testing.T) { @@ -90,18 +98,26 @@ func TestHealthcheckWithRedis(t *testing.T) { Interval: time.Duration(1), Timeout: time.Duration(1), // clears state immediately }, + Server: config.Server{ + Upstream: config.Upstream{ + Host: "testing.local", + Scheme: "http", + Endpoints: []string{utils.GetEnv("NGINX_HOST_80", "localhost:40080")}, + }, + }, } + domainID := config.Config.Server.Upstream.GetDomainID() + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) + req, err := http.NewRequest("GET", "/healthcheck", nil) + req.Host = "testing.local" assert.Nil(t, err) rr := httptest.NewRecorder() h := http.HandlerFunc(handler.HandleHealthcheck(config.Config)) - domainID := config.Config.Server.Upstream.GetDomainID() - circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) - h.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) diff --git a/server/handler/http.go b/server/handler/http.go index 478f8023..75a1e3a1 100644 --- a/server/handler/http.go +++ b/server/handler/http.go @@ -16,8 +16,7 @@ import ( "os" "time" - log "github.com/sirupsen/logrus" - + "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/logger" "github.com/fabiocicerchia/go-proxy-cache/server/response" "github.com/fabiocicerchia/go-proxy-cache/server/storage" @@ -41,9 +40,9 @@ var CacheStatusLabel = map[int]string{ CacheStatusStale: "STALE", } -var enableStoringResponse = true -var enableCachedResponse = true -var enableLoggingRequest = true +const enableStoringResponse = true +const enableCachedResponse = true +const enableLoggingRequest = true // DefaultTransportMaxIdleConns - Default value used for http.Transport.MaxIdleConns. var DefaultTransportMaxIdleConns int = 1000 @@ -61,28 +60,32 @@ var DefaultTransportDialTimeout time.Duration = 15 * time.Second func (rc RequestCall) HandleHTTPRequestAndProxy() { cached := CacheStatusMiss - if enableCachedResponse { + forceFresh := rc.Request.Header.Get(response.CacheBypassHeader) == "1" + if forceFresh { + rc.GetLogger().Warningf("Forcing Fresh Content on %v", rc.Request.URL.String()) + } + + if enableCachedResponse && !forceFresh { cached = rc.serveCachedContent() } if cached == CacheStatusMiss { + rc.Response.Header().Set(response.CacheStatusHeader, response.CacheStatusHeaderMiss) rc.serveReverseProxyHTTP() } if enableLoggingRequest { // HIT and STALE considered the same. - logger.LogRequest(*rc.Request, *rc.Response, cached != CacheStatusMiss, CacheStatusLabel[cached]) + logger.LogRequest(rc.Request, *rc.Response, rc.ReqID, cached != CacheStatusMiss, CacheStatusLabel[cached]) } } func (rc RequestCall) serveCachedContent() int { rcDTO := ConvertToRequestCallDTO(rc) - uriObj, err := storage.RetrieveCachedContent(rcDTO) + uriObj, err := storage.RetrieveCachedContent(rcDTO, rc.GetLogger()) if err != nil { - rc.Response.Header().Set(response.CacheStatusHeader, response.CacheStatusHeaderMiss) - - log.Warnf("Error on serving cached content: %s", err) + rc.GetLogger().Warnf("Error on serving cached content: %s", err) return CacheStatusMiss } @@ -103,9 +106,9 @@ func (rc RequestCall) serveCachedContent() int { func (rc RequestCall) serveReverseProxyHTTP() { proxyURL := rc.GetUpstreamURL() - log.Debugf("ProxyURL: %s", proxyURL.String()) - log.Debugf("Req URL: %s", rc.Request.URL.String()) - log.Debugf("Req Host: %s", rc.Request.Host) + rc.GetLogger().Debugf("ProxyURL: %s", proxyURL.String()) + rc.GetLogger().Debugf("Req URL: %s", rc.Request.URL.String()) + rc.GetLogger().Debugf("Req Host: %s", rc.Request.Host) proxy := httputil.NewSingleHostReverseProxy(&proxyURL) proxy.Transport = rc.patchProxyTransport() @@ -119,8 +122,17 @@ func (rc RequestCall) serveReverseProxyHTTP() { gpcDirector(req) } - // Forward Original Request - proxy.ServeHTTP(rc.Response, rc.Request) + serveNotModified := rc.GetResponseWithETag(proxy) + if serveNotModified { + rc.Response.SendNotModifiedResponse() + return + } + + if rc.DomainConfig.Server.GZip { + WrapResponseForGZip(rc.Response, &rc.Request) + } + + rc.Response.SendResponse() rc.storeResponse() } @@ -130,26 +142,27 @@ func (rc RequestCall) storeResponse() { return } + rcDTO := ConvertToRequestCallDTO(rc) + // Make it sync for testing // TODO: Make it customizable? if os.Getenv("GPC_SYNC_STORING") == "1" { - log.Debugf("Sync Store Response: %s", rc.Request.URL.String()) + rc.GetLogger().Debugf("Sync Store Response: %s", rc.Request.URL.String()) - rc.doStoreResponse() + doStoreResponse(rcDTO, rc.DomainConfig.Cache) return } - log.Debugf("Async Store Response: %s", rc.Request.URL.String()) + rc.GetLogger().Debugf("Async Store Response: %s", rc.Request.URL.String()) + // go rc.doStoreResponse() queue.Dispatcher.Do(func() { - rc.doStoreResponse() + doStoreResponse(rcDTO, rc.DomainConfig.Cache) }) } -func (rc RequestCall) doStoreResponse() { - rcDTO := ConvertToRequestCallDTO(rc) - - stored, err := storage.StoreGeneratedPage(rcDTO, rc.DomainConfig.Cache) +func doStoreResponse(rcDTO storage.RequestCallDTO, configCache config.Cache) { + stored, err := storage.StoreGeneratedPage(rcDTO, configCache) if !stored || err != nil { - logger.Log(*rc.Request, fmt.Sprintf("Not Stored: %v", err)) + logger.Log(rcDTO.Request, rcDTO.ReqID, fmt.Sprintf("Not Stored: %v", err)) } } diff --git a/server/handler/http_functional_test.go b/server/handler/http_functional_test.go index e6658a0d..ea18c116 100644 --- a/server/handler/http_functional_test.go +++ b/server/handler/http_functional_test.go @@ -12,13 +12,13 @@ package handler_test // Repo: https://github.com/fabiocicerchia/go-proxy-cache import ( - "fmt" "crypto/tls" "net/http" "net/http/httptest" "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/fabiocicerchia/go-proxy-cache/cache/engine" @@ -57,25 +57,25 @@ func getCommonConfig() config.Configuration { // --- HTTP func TestHTTPEndToEndCallRedirect(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 1 - cfg.Server.Upstream.Host = "testing.local" - cfg.Server.Upstream.Scheme = "http" - cfg.Server.Upstream.HTTP2HTTPS = true - cfg.Server.Upstream.RedirectStatusCode = 301 - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + config.Config = getCommonConfig() + config.Config.Cache.DB = 1 + config.Config.Server.Upstream.Host = "testing.local" + config.Config.Server.Upstream.Scheme = "http" + config.Config.Server.Upstream.HTTP2HTTPS = true + config.Config.Server.Upstream.RedirectStatusCode = 301 + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -87,32 +87,32 @@ func TestHTTPEndToEndCallRedirect(t *testing.T) { } func TestHTTPEndToEndCallWithoutCache(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 2 - cfg.Domains = make(config.Domains) - conf := cfg - cfg.Server.Upstream = config.Upstream{ + config.Config = getCommonConfig() + config.Config.Cache.DB = 2 + config.Config.Domains = make(config.Domains) + conf := config.Config + config.Config.Server.Upstream = config.Upstream{ Host: "www.w3.org", Scheme: "https", Endpoints: []string{"www.w3.org"}, } - cfg.Domains["www.w3.org"] = conf + config.Config.Domains["www.w3.org"] = conf - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) engine.GetConn(domainID).Close() req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -130,30 +130,30 @@ func TestHTTPEndToEndCallWithoutCache(t *testing.T) { } func TestHTTPEndToEndCallWithCacheMiss(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 3 - cfg.Server.Upstream = config.Upstream{ + config.Config = getCommonConfig() + config.Config.Cache.DB = 3 + config.Config.Server.Upstream = config.Upstream{ Host: "www.w3.org", Scheme: "http", Endpoints: []string{"www.w3.org"}, } - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) _, err := engine.GetConn(domainID).PurgeAll() assert.Nil(t, err) req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -171,7 +171,7 @@ func TestHTTPEndToEndCallWithCacheMiss(t *testing.T) { } func TestHTTPEndToEndCallWithCacheHit(t *testing.T) { - cfg := config.Configuration{ + config.Config = config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "www.w3.org", @@ -194,25 +194,101 @@ func TestHTTPEndToEndCallWithCacheHit(t *testing.T) { }, } - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) _, _ = engine.GetConn(domainID).PurgeAll() - time.Sleep(1 * time.Second) + // --- MISS + + req, err := http.NewRequest("GET", "/", nil) + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host + assert.Nil(t, err) + + rr := httptest.NewRecorder() + h := http.HandlerFunc(handler.HandleRequest) + + h.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + assert.Equal(t, "MISS", rr.HeaderMap["X-Go-Proxy-Cache-Status"][0]) + + body := rr.Body.String() + + assert.Contains(t, body, "World Wide Web Consortium (W3C)`) + assert.Contains(t, body, "\n\n") + + // --- HIT + + req, err = http.NewRequest("GET", "/", nil) + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host + assert.Nil(t, err) + + rr = httptest.NewRecorder() + h.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + assert.Equal(t, "HIT", rr.HeaderMap["X-Go-Proxy-Cache-Status"][0]) + + body = rr.Body.String() + + assert.Contains(t, body, "World Wide Web Consortium (W3C)`) + assert.Contains(t, body, "\n\n") + + tearDownHTTPFunctional() +} + +func TestHTTPEndToEndCallWithCacheBypass(t *testing.T) { + config.Config = config.Configuration{ + Server: config.Server{ + Upstream: config.Upstream{ + Host: "www.w3.org", + Scheme: "http", + Endpoints: []string{"www.w3.org"}, + }, + }, + Cache: config.Cache{ + Host: utils.GetEnv("REDIS_HOST", "localhost"), + Port: "6379", + DB: 4, + AllowedStatuses: []int{200, 301, 302}, + AllowedMethods: []string{"HEAD", "GET"}, + }, + CircuitBreaker: circuit_breaker.CircuitBreaker{ + Threshold: 2, // after 2nd request, if meet FailureRate goes open. + FailureRate: 0.5, // 1 out of 2 fails, or more + Interval: time.Duration(1), // clears counts immediately + Timeout: time.Duration(1), // clears state immediately + }, + } + + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) + + _, _ = engine.GetConn(domainID).PurgeAll() // --- MISS req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -229,9 +305,9 @@ func TestHTTPEndToEndCallWithCacheHit(t *testing.T) { // --- HIT req, err = http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr = httptest.NewRecorder() @@ -247,11 +323,36 @@ func TestHTTPEndToEndCallWithCacheHit(t *testing.T) { assert.Contains(t, body, `World Wide Web Consortium (W3C)`) assert.Contains(t, body, "\n\n") + // --- BYPASS + + req, err = http.NewRequest("GET", "/", nil) + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host + // Need to fetch fresh content. + req.Header = http.Header{ + "X-Go-Proxy-Cache-Force-Fresh": []string{"1"}, + } + assert.Nil(t, err) + + rr = httptest.NewRecorder() + h.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + + assert.Equal(t, "MISS", rr.HeaderMap["X-Go-Proxy-Cache-Status"][0]) + + body = rr.Body.String() + + assert.Contains(t, body, "World Wide Web Consortium (W3C)`) + assert.Contains(t, body, "\n\n") + tearDownHTTPFunctional() } func TestHTTPEndToEndCallWithCacheStale(t *testing.T) { - cfg := config.Configuration{ + config.Config = config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "www.w3.org", @@ -274,26 +375,23 @@ func TestHTTPEndToEndCallWithCacheStale(t *testing.T) { }, } - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) _, _ = engine.GetConn(domainID).PurgeAll() - time.Sleep(1 * time.Second) - // --- MISS req, err := http.NewRequest("GET", "/standards/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host - fmt.Println(req.URL) + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -310,9 +408,9 @@ func TestHTTPEndToEndCallWithCacheStale(t *testing.T) { // --- HIT req, err = http.NewRequest("GET", "/standards/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr = httptest.NewRecorder() @@ -328,18 +426,15 @@ func TestHTTPEndToEndCallWithCacheStale(t *testing.T) { assert.Contains(t, body, `Standards - W3C`) assert.Contains(t, body, "\n") - time.Sleep(10 * time.Second) - // Manual Timeout All Fresh Keys _, _ = engine.GetConn(domainID).DelWildcard("DATA@@GET@@http://www.w3.org/standards/@@*/fresh") - time.Sleep(10 * time.Second) // --- STALE req, err = http.NewRequest("GET", "/standards/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr = httptest.NewRecorder() @@ -359,7 +454,7 @@ func TestHTTPEndToEndCallWithCacheStale(t *testing.T) { } func TestHTTPEndToEndCallWithHTTPSRedirect(t *testing.T) { - cfg := config.Configuration{ + config.Config = config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "testing.local", @@ -370,19 +465,19 @@ func TestHTTPEndToEndCallWithHTTPSRedirect(t *testing.T) { }, }, } - cfg.Cache.DB = 6 + config.Config.Cache.DB = 6 - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -394,21 +489,21 @@ func TestHTTPEndToEndCallWithHTTPSRedirect(t *testing.T) { } func TestHTTPEndToEndCallWithMissingDomain(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 7 - cfg.Domains = make(config.Domains) - conf := cfg - cfg.Server.Upstream = config.Upstream{ + config.Config = getCommonConfig() + config.Config.Cache.DB = 7 + config.Config.Domains = make(config.Domains) + conf := config.Config + config.Config.Server.Upstream = config.Upstream{ Host: "www.w3.org", Scheme: "http", Endpoints: []string{"www.w3.org"}, } - cfg.Domains["www.w3.org"] = conf + config.Config.Domains["www.w3.org"] = conf - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) engine.GetConn(domainID).Close() @@ -419,7 +514,7 @@ func TestHTTPEndToEndCallWithMissingDomain(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -431,31 +526,31 @@ func TestHTTPEndToEndCallWithMissingDomain(t *testing.T) { // --- HTTPS func TestHTTPSEndToEndCallRedirect(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 8 - cfg.Server.Upstream.Host = "testing.local" - cfg.Server.Upstream.Scheme = "http" - cfg.Server.Upstream.HTTP2HTTPS = true - cfg.Server.Upstream.RedirectStatusCode = 301 - cfg.Server.Upstream.Endpoints = []string{utils.GetEnv("NGINX_HOST_443", "localhost:40443")} + config.Config = getCommonConfig() + config.Config.Cache.DB = 8 + config.Config.Server.Upstream.Host = "testing.local" + config.Config.Server.Upstream.Scheme = "http" + config.Config.Server.Upstream.HTTP2HTTPS = true + config.Config.Server.Upstream.RedirectStatusCode = 301 + config.Config.Server.Upstream.Endpoints = []string{utils.GetEnv("NGINX_HOST_443", "localhost:40443")} // This is because there's no client sending their certificate, so the handshake will be broken with a // `remote error: tls: bad certificate`. // More details on: https://www.prakharsrivastav.com/posts/from-http-to-https-using-go/ - cfg.Server.Upstream.InsecureBridge = true + config.Config.Server.Upstream.InsecureBridge = true - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -467,32 +562,32 @@ func TestHTTPSEndToEndCallRedirect(t *testing.T) { } func TestHTTPSEndToEndCallWithoutCache(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 9 - cfg.Domains = make(config.Domains) - conf := cfg - cfg.Server.Upstream = config.Upstream{ + config.Config = getCommonConfig() + config.Config.Cache.DB = 9 + config.Config.Domains = make(config.Domains) + conf := config.Config + config.Config.Server.Upstream = config.Upstream{ Host: "www.w3.org", Scheme: "https", Endpoints: []string{"www.w3.org"}, } - cfg.Domains["www.w3.org"] = conf + config.Config.Domains["www.w3.org"] = conf - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) engine.GetConn(domainID).Close() req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -510,30 +605,30 @@ func TestHTTPSEndToEndCallWithoutCache(t *testing.T) { } func TestHTTPSEndToEndCallWithCacheMiss(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 10 - cfg.Server.Upstream = config.Upstream{ + config.Config = getCommonConfig() + config.Config.Cache.DB = 10 + config.Config.Server.Upstream = config.Upstream{ Host: "www.w3.org", Scheme: "https", Endpoints: []string{"www.w3.org"}, } - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) _, err := engine.GetConn(domainID).PurgeAll() assert.Nil(t, err) req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -551,7 +646,7 @@ func TestHTTPSEndToEndCallWithCacheMiss(t *testing.T) { } func TestHTTPSEndToEndCallWithCacheHit(t *testing.T) { - cfg := config.Configuration{ + config.Config = config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "www.w3.org", @@ -574,24 +669,24 @@ func TestHTTPSEndToEndCallWithCacheHit(t *testing.T) { }, } - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) _, _ = engine.GetConn(domainID).PurgeAll() // --- MISS req, err := http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host req.TLS = &tls.ConnectionState{} // mock a fake https assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) @@ -605,14 +700,12 @@ func TestHTTPSEndToEndCallWithCacheHit(t *testing.T) { assert.Contains(t, body, `World Wide Web Consortium (W3C)`) assert.Contains(t, body, "\n\n") - time.Sleep(1 * time.Second) - // --- HIT req, err = http.NewRequest("GET", "/", nil) - req.URL.Scheme = cfg.Server.Upstream.Scheme - req.URL.Host = cfg.Server.Upstream.Host - req.Host = cfg.Server.Upstream.Host + req.URL.Scheme = config.Config.Server.Upstream.Scheme + req.URL.Host = config.Config.Server.Upstream.Host + req.Host = config.Config.Server.Upstream.Host req.TLS = &tls.ConnectionState{} // mock a fake https assert.Nil(t, err) @@ -633,21 +726,21 @@ func TestHTTPSEndToEndCallWithCacheHit(t *testing.T) { } func TestHTTPSEndToEndCallWithMissingDomain(t *testing.T) { - cfg := getCommonConfig() - cfg.Cache.DB = 12 - cfg.Domains = make(config.Domains) - conf := cfg - cfg.Server.Upstream = config.Upstream{ + config.Config = getCommonConfig() + config.Config.Cache.DB = 12 + config.Config.Domains = make(config.Domains) + conf := config.Config + config.Config.Server.Upstream = config.Upstream{ Host: "www.w3.org", Scheme: "https", Endpoints: []string{"www.w3.org"}, } - cfg.Domains["www.w3.org"] = conf + config.Config.Domains["www.w3.org"] = conf - domainID := cfg.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - circuit_breaker.InitCircuitBreaker(domainID, cfg.CircuitBreaker) - engine.InitConn(domainID, cfg.Cache) + domainID := config.Config.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) engine.GetConn(domainID).Close() @@ -658,7 +751,7 @@ func TestHTTPSEndToEndCallWithMissingDomain(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(cfg)) + h := http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) diff --git a/server/handler/http_unit_test.go b/server/handler/http_unit_test.go index 50956e88..c40c0e3b 100644 --- a/server/handler/http_unit_test.go +++ b/server/handler/http_unit_test.go @@ -25,7 +25,7 @@ import ( ) func TestProxyCallOneItemInLB(t *testing.T) { - config.Config = config.Configuration{ + cfg := config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "developer.mozilla.org", @@ -35,7 +35,7 @@ func TestProxyCallOneItemInLB(t *testing.T) { }, } - reqMock := &http.Request{ + reqMock := http.Request{ Proto: "HTTPS", Method: "POST", RemoteAddr: "127.0.0.1", @@ -45,12 +45,12 @@ func TestProxyCallOneItemInLB(t *testing.T) { }, } - domainID := config.Config.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + domainID := cfg.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - r := handler.RequestCall{Request: reqMock, DomainConfig: &config.Config} + r := handler.RequestCall{Request: reqMock, DomainConfig: cfg} proxyURL := r.GetUpstreamURL() - r.ProxyDirector(r.Request) + r.ProxyDirector(&r.Request) assert.Equal(t, "localhost", r.Request.Header.Get("X-Forwarded-Host")) assert.Equal(t, "http", r.Request.Header.Get("X-Forwarded-Proto")) @@ -61,7 +61,7 @@ func TestProxyCallOneItemInLB(t *testing.T) { } func TestProxyCallOneItemWithPortInLB(t *testing.T) { - config.Config = config.Configuration{ + cfg := config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "developer.mozilla.org", @@ -71,7 +71,7 @@ func TestProxyCallOneItemWithPortInLB(t *testing.T) { }, } - reqMock := &http.Request{ + reqMock := http.Request{ Proto: "HTTPS", Method: "POST", RemoteAddr: "127.0.0.1", @@ -81,12 +81,12 @@ func TestProxyCallOneItemWithPortInLB(t *testing.T) { }, } - domainID := config.Config.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + domainID := cfg.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - r := handler.RequestCall{Request: reqMock, DomainConfig: &config.Config} + r := handler.RequestCall{Request: reqMock, DomainConfig: cfg} proxyURL := r.GetUpstreamURL() - r.ProxyDirector(r.Request) + r.ProxyDirector(&r.Request) assert.Equal(t, "localhost", r.Request.Header.Get("X-Forwarded-Host")) assert.Equal(t, "http", r.Request.Header.Get("X-Forwarded-Proto")) @@ -97,7 +97,7 @@ func TestProxyCallOneItemWithPortInLB(t *testing.T) { } func TestProxyCallThreeItemsInLB(t *testing.T) { - config.Config = config.Configuration{ + cfg := config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "developer.mozilla.org", @@ -107,12 +107,12 @@ func TestProxyCallThreeItemsInLB(t *testing.T) { }, } - domainID := config.Config.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + domainID := cfg.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) // --- FIRST ROUND - reqMock := &http.Request{ + reqMock := http.Request{ Proto: "HTTPS", Method: "POST", RemoteAddr: "127.0.0.1", @@ -122,9 +122,9 @@ func TestProxyCallThreeItemsInLB(t *testing.T) { }, } - r := handler.RequestCall{Request: reqMock, DomainConfig: &config.Config} + r := handler.RequestCall{Request: reqMock, DomainConfig: cfg} proxyURL := r.GetUpstreamURL() - r.ProxyDirector(r.Request) + r.ProxyDirector(&r.Request) assert.Equal(t, "localhost", r.Request.Header.Get("X-Forwarded-Host")) assert.Equal(t, "http", r.Request.Header.Get("X-Forwarded-Proto")) @@ -135,7 +135,7 @@ func TestProxyCallThreeItemsInLB(t *testing.T) { // --- SECOND ROUND - reqMock = &http.Request{ + reqMock = http.Request{ Proto: "HTTPS", Method: "POST", RemoteAddr: "127.0.0.1", @@ -145,9 +145,9 @@ func TestProxyCallThreeItemsInLB(t *testing.T) { }, } - r = handler.RequestCall{Request: reqMock, DomainConfig: &config.Config} + r = handler.RequestCall{Request: reqMock, DomainConfig: cfg} proxyURL = r.GetUpstreamURL() - r.ProxyDirector(r.Request) + r.ProxyDirector(&r.Request) assert.Equal(t, "localhost", r.Request.Header.Get("X-Forwarded-Host")) assert.Equal(t, "http", r.Request.Header.Get("X-Forwarded-Proto")) @@ -158,7 +158,7 @@ func TestProxyCallThreeItemsInLB(t *testing.T) { } func TestXForwardedFor(t *testing.T) { - config.Config = config.Configuration{ + cfg := config.Configuration{ Server: config.Server{ Upstream: config.Upstream{ Host: "developer.mozilla.org", @@ -168,7 +168,7 @@ func TestXForwardedFor(t *testing.T) { }, } - reqMock := &http.Request{ + reqMock := http.Request{ Proto: "HTTPS", Method: "POST", RemoteAddr: "127.0.0.1", @@ -180,12 +180,12 @@ func TestXForwardedFor(t *testing.T) { TLS: &tls.ConnectionState{}, // mock a fake https } - domainID := config.Config.Server.Upstream.GetDomainID() - balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) + domainID := cfg.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) - r := handler.RequestCall{Request: reqMock, DomainConfig: &config.Config} + r := handler.RequestCall{Request: reqMock, DomainConfig: cfg} _ = r.GetUpstreamURL() - r.ProxyDirector(r.Request) + r.ProxyDirector(&r.Request) assert.Equal(t, "https", r.Request.Header.Get("X-Forwarded-Proto")) assert.Equal(t, "192.168.1.1, 127.0.0.1", r.Request.Header.Get("X-Forwarded-For")) diff --git a/server/handler/model.go b/server/handler/model.go index e169729f..2d511c71 100644 --- a/server/handler/model.go +++ b/server/handler/model.go @@ -14,6 +14,7 @@ import ( "net/url" "strings" + log "github.com/sirupsen/logrus" "github.com/yhat/wsutil" "github.com/fabiocicerchia/go-proxy-cache/config" @@ -34,9 +35,28 @@ var SchemeWSS string = "wss" // RequestCall - Main object containing request and response. type RequestCall struct { + ReqID string Response *response.LoggedResponseWriter - Request *http.Request - DomainConfig *config.Configuration + Request http.Request + DomainConfig config.Configuration +} + +// GetLogger - Get logger instance with RequestID. +func (rc RequestCall) GetLogger() *log.Entry { + return log.WithFields(log.Fields{ + "ReqID": rc.ReqID, + }) +} + +// IsLegitRequest - Check whether a request is bound on the right Host and Port. +func (rc RequestCall) IsLegitRequest(listeningPort string) bool { + hostMatch := rc.DomainConfig.Server.Upstream.Host == rc.GetHostname() + legitPort := isLegitPort(rc.DomainConfig.Server.Port, listeningPort) + + rc.GetLogger().Debugf("Is Hostname matching Request and Configuration? %v - Request: %s - Config: %s", hostMatch, rc.GetHostname(), rc.DomainConfig.Server.Upstream.Host) + rc.GetLogger().Debugf("Is Port matching Request and Configuration? %v - Request: %s - Config: %s", legitPort, listeningPort, rc.DomainConfig.Server.Port) + + return hostMatch && legitPort } // GetRequestURL - Returns the valid Request URL (with Scheme and Host). @@ -76,12 +96,7 @@ func (rc RequestCall) GetScheme() string { return SchemeHTTP } -// GetConfiguredScheme - Returns configured request scheme (could be wildcard). -func (rc RequestCall) GetConfiguredScheme() string { - return rc.DomainConfig.Server.Upstream.Scheme -} - // IsWebSocket - Checks whether a request is a websocket. func (rc RequestCall) IsWebSocket() bool { - return wsutil.IsWebSocketRequest(rc.Request) + return wsutil.IsWebSocketRequest(&rc.Request) // TODO: don't like the reference } diff --git a/server/handler/purge.go b/server/handler/purge.go index 88d1de88..472b61fb 100644 --- a/server/handler/purge.go +++ b/server/handler/purge.go @@ -12,8 +12,8 @@ package handler import ( "net/http" + "github.com/fabiocicerchia/go-proxy-cache/server/logger" "github.com/fabiocicerchia/go-proxy-cache/server/storage" - log "github.com/sirupsen/logrus" ) // HandlePurge - Purges the cache for the requested URI. @@ -22,14 +22,18 @@ func (rc RequestCall) HandlePurge() { status, err := storage.PurgeCachedContent(rc.DomainConfig.Server.Upstream, rcDTO) if !status || err != nil { - rc.Response.WriteHeader(http.StatusNotFound) + rc.Response.ForceWriteHeader(http.StatusNotFound) _ = rc.Response.WriteBody("KO") - log.Warnf("URL Not Purged %s: %v\n", rc.Request.URL.String(), err) + rc.GetLogger().Warnf("URL Not Purged %s: %v\n", rc.Request.URL.String(), err) return } - rc.Response.WriteHeader(http.StatusOK) + rc.Response.ForceWriteHeader(http.StatusOK) _ = rc.Response.WriteBody("OK") + + if enableLoggingRequest { + logger.LogRequest(rc.Request, *rc.Response, rc.ReqID, false, "-") + } } diff --git a/server/handler/purge_test.go b/server/handler/purge_test.go index 15e3777d..a6c34267 100644 --- a/server/handler/purge_test.go +++ b/server/handler/purge_test.go @@ -18,13 +18,15 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/fabiocicerchia/go-proxy-cache/cache/engine" "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/balancer" "github.com/fabiocicerchia/go-proxy-cache/server/handler" "github.com/fabiocicerchia/go-proxy-cache/utils" circuit_breaker "github.com/fabiocicerchia/go-proxy-cache/utils/circuit-breaker" - "github.com/stretchr/testify/assert" ) func TestEndToEndCallPurgeDoNothing(t *testing.T) { @@ -55,7 +57,7 @@ func TestEndToEndCallPurgeDoNothing(t *testing.T) { domainID := config.Config.Server.Upstream.GetDomainID() circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) // --- PURGE @@ -66,7 +68,7 @@ func TestEndToEndCallPurgeDoNothing(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(config.Config)) + h := http.HandlerFunc(handler.HandleRequest) _, err = engine.GetConn(domainID).PurgeAll() assert.Nil(t, err) @@ -78,8 +80,6 @@ func TestEndToEndCallPurgeDoNothing(t *testing.T) { body := rr.Body.String() assert.Equal(t, body, "KO") - - time.Sleep(1 * time.Second) } func TestEndToEndCallPurge(t *testing.T) { @@ -111,7 +111,7 @@ func TestEndToEndCallPurge(t *testing.T) { domainID := config.Config.Server.Upstream.GetDomainID() balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) // --- MISS @@ -123,7 +123,7 @@ func TestEndToEndCallPurge(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(config.Config)) + h := http.HandlerFunc(handler.HandleRequest) _, err = engine.GetConn(domainID).PurgeAll() assert.Nil(t, err) @@ -150,7 +150,7 @@ func TestEndToEndCallPurge(t *testing.T) { assert.Nil(t, err) rr = httptest.NewRecorder() - h = http.HandlerFunc(handler.HandleRequest(config.Config)) + h = http.HandlerFunc(handler.HandleRequest) h.ServeHTTP(rr, req) assert.Equal(t, http.StatusOK, rr.Code) @@ -181,8 +181,6 @@ func TestEndToEndCallPurge(t *testing.T) { assert.Equal(t, "OK", body) - time.Sleep(1 * time.Second) - // --- MISS req, err = http.NewRequest("GET", "/", nil) diff --git a/server/handler/redirect.go b/server/handler/redirect.go index b5951a3d..7cbf10f5 100644 --- a/server/handler/redirect.go +++ b/server/handler/redirect.go @@ -11,8 +11,6 @@ package handler import ( "net/http" - - log "github.com/sirupsen/logrus" ) // RedirectToHTTPS - Redirects from HTTP to HTTPS. @@ -20,7 +18,8 @@ func (rc RequestCall) RedirectToHTTPS() { targetURL := rc.GetRequestURL() targetURL.Scheme = SchemeHTTPS - log.Infof("Redirect to: %s", targetURL.String()) + rc.GetLogger().Infof("Redirect to: %s", targetURL.String()) - http.Redirect(rc.Response, rc.Request, targetURL.String(), rc.DomainConfig.Server.Upstream.RedirectStatusCode) + // Just write to client, no need to cache this response. + http.Redirect(rc.Response.ResponseWriter, &rc.Request, targetURL.String(), rc.DomainConfig.Server.Upstream.RedirectStatusCode) } diff --git a/server/handler/utils.go b/server/handler/utils.go index ce31e16e..514788c3 100644 --- a/server/handler/utils.go +++ b/server/handler/utils.go @@ -15,9 +15,12 @@ import ( "net" "net/http" "net/url" + "os" "strconv" "strings" + log "github.com/sirupsen/logrus" + "github.com/fabiocicerchia/go-proxy-cache/cache" "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/balancer" @@ -27,10 +30,17 @@ import ( // ConvertToRequestCallDTO - Generates a storage DTO containing request, response and cache settings. func ConvertToRequestCallDTO(rc RequestCall) storage.RequestCallDTO { + responseHeaders := http.Header{} + if rc.Response != nil { + responseHeaders = rc.Response.Header() + } + return storage.RequestCallDTO{ + ReqID: rc.ReqID, Response: *rc.Response, - Request: *rc.Request, + Request: rc.Request, CacheObject: cache.Object{ + ReqID: rc.ReqID, AllowedStatuses: rc.DomainConfig.Cache.AllowedStatuses, AllowedMethods: rc.DomainConfig.Cache.AllowedMethods, DomainID: rc.DomainConfig.Server.Upstream.GetDomainID(), @@ -39,7 +49,7 @@ func ConvertToRequestCallDTO(rc RequestCall) storage.RequestCallDTO { Method: rc.Request.Method, StatusCode: rc.Response.StatusCode, RequestHeaders: rc.Request.Header, - ResponseHeaders: rc.Response.Header(), + ResponseHeaders: responseHeaders, Content: rc.Response.Content, }, }, @@ -59,6 +69,12 @@ func getListeningPort(ctx context.Context) string { } func isLegitPort(port config.Port, listeningPort string) bool { + // When running the functional tests there's no server listening (so no port open). + if os.Getenv("TESTING") == "1" && listeningPort == "" { + log.Warn("Testing Environment found, and listening port is empty") + return true + } + return port.HTTP == listeningPort || port.HTTPS == listeningPort } @@ -132,9 +148,9 @@ func (rc RequestCall) ProxyDirector(req *http.Request) { // proxy server, r.URL.Host is the host of the target server and r.Host is // the host of the proxy server itself. // Ref: https://stackoverflow.com/a/42926149/888162 - rc.Request.Header.Set("X-Forwarded-Host", rc.Request.Header.Get("Host")) + req.Header.Set("X-Forwarded-Host", rc.Request.Header.Get("Host")) - rc.Request.Header.Set("X-Forwarded-Proto", rc.GetScheme()) + req.Header.Set("X-Forwarded-Proto", rc.GetScheme()) previousXForwardedFor := rc.Request.Header.Get("X-Forwarded-For") clientIP := utils.StripPort(rc.Request.RemoteAddr) @@ -144,7 +160,7 @@ func (rc RequestCall) ProxyDirector(req *http.Request) { xForwardedFor = previousXForwardedFor + ", " + xForwardedFor } - rc.Request.Header.Set("X-Forwarded-For", xForwardedFor) + req.Header.Set("X-Forwarded-For", xForwardedFor) - rc.Request.Host = host + req.Host = host } diff --git a/server/handler/utils_test.go b/server/handler/utils_test.go new file mode 100644 index 00000000..86e51041 --- /dev/null +++ b/server/handler/utils_test.go @@ -0,0 +1,54 @@ +// +build all unit + +package handler_test + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/config" + "github.com/fabiocicerchia/go-proxy-cache/server/balancer" + "github.com/fabiocicerchia/go-proxy-cache/server/handler" +) + +func TestGetUpstreamURLWithWildcard(t *testing.T) { + cfg := config.Configuration{ + Server: config.Server{ + Upstream: config.Upstream{ + Host: "developer.mozilla.org", + Scheme: "*", // emulate config.copyOverWithUpstream:179 + Endpoints: []string{"server1"}, + }, + }, + } + + reqMock := http.Request{ + Method: "POST", + RemoteAddr: "127.0.0.1", + URL: &url.URL{Path: "/path/to/file"}, + Header: http.Header{ + "Host": []string{"localhost"}, + }, + } + + domainID := cfg.Server.Upstream.GetDomainID() + balancer.InitRoundRobin(domainID, cfg.Server.Upstream.Endpoints) + + r := handler.RequestCall{Request: reqMock, DomainConfig: cfg} + proxyURL := r.GetUpstreamURL() + + assert.Equal(t, "server1:80", proxyURL.Host) + assert.Equal(t, "http", proxyURL.Scheme) +} diff --git a/server/handler/ws.go b/server/handler/ws.go index f4b379b1..66ad4f0c 100644 --- a/server/handler/ws.go +++ b/server/handler/ws.go @@ -12,9 +12,9 @@ package handler import ( "net/http" - "github.com/fabiocicerchia/go-proxy-cache/server/logger" - log "github.com/sirupsen/logrus" "github.com/yhat/wsutil" + + "github.com/fabiocicerchia/go-proxy-cache/server/logger" ) // HandleWSRequestAndProxy - Handles the websocket requests and proxies to backend server. @@ -22,16 +22,16 @@ func (rc RequestCall) HandleWSRequestAndProxy() { rc.serveReverseProxyWS() if enableLoggingRequest { - logger.LogRequest(*rc.Request, *rc.Response, false, CacheStatusLabel[CacheStatusMiss]) + logger.LogRequest(rc.Request, *rc.Response, rc.ReqID, false, CacheStatusLabel[CacheStatusMiss]) } } func (rc RequestCall) serveReverseProxyWS() { proxyURL := rc.GetUpstreamURL() - log.Debugf("ProxyURL: %s", proxyURL.String()) - log.Debugf("Req URL: %s", rc.Request.URL.String()) - log.Debugf("Req Host: %s", rc.Request.Host) + rc.GetLogger().Debugf("ProxyURL: %s", proxyURL.String()) + rc.GetLogger().Debugf("Req URL: %s", rc.Request.URL.String()) + rc.GetLogger().Debugf("Req Host: %s", rc.Request.Host) proxy := wsutil.NewSingleHostReverseProxy(&proxyURL) @@ -48,5 +48,5 @@ func (rc RequestCall) serveReverseProxyWS() { proxy.Dial = transport.Dial proxy.TLSClientConfig = transport.TLSClientConfig - proxy.ServeHTTP(rc.Response, rc.Request) + proxy.ServeHTTP(rc.Response, &rc.Request) } diff --git a/server/handler/ws_test.go b/server/handler/ws_test.go index 88b99e0e..64ddba1f 100644 --- a/server/handler/ws_test.go +++ b/server/handler/ws_test.go @@ -18,13 +18,15 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/fabiocicerchia/go-proxy-cache/cache/engine" "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/balancer" "github.com/fabiocicerchia/go-proxy-cache/server/handler" "github.com/fabiocicerchia/go-proxy-cache/utils" circuit_breaker "github.com/fabiocicerchia/go-proxy-cache/utils/circuit-breaker" - "github.com/stretchr/testify/assert" ) func TestEndToEndHandleWSRequestAndProxy(t *testing.T) { @@ -56,7 +58,7 @@ func TestEndToEndHandleWSRequestAndProxy(t *testing.T) { domainID := config.Config.Server.Upstream.GetDomainID() balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) // --- WEBSOCKET @@ -71,7 +73,7 @@ func TestEndToEndHandleWSRequestAndProxy(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(config.Config)) + h := http.HandlerFunc(handler.HandleRequest) _, err = engine.GetConn(domainID).PurgeAll() assert.Nil(t, err) @@ -114,7 +116,7 @@ func TestEndToEndHandleWSRequestAndProxySecure(t *testing.T) { domainID := config.Config.Server.Upstream.GetDomainID() balancer.InitRoundRobin(domainID, config.Config.Server.Upstream.Endpoints) circuit_breaker.InitCircuitBreaker(domainID, config.Config.CircuitBreaker) - engine.InitConn(domainID, config.Config.Cache) + engine.InitConn(domainID, config.Config.Cache, log.StandardLogger()) // --- WEBSOCKET @@ -130,7 +132,7 @@ func TestEndToEndHandleWSRequestAndProxySecure(t *testing.T) { assert.Nil(t, err) rr := httptest.NewRecorder() - h := http.HandlerFunc(handler.HandleRequest(config.Config)) + h := http.HandlerFunc(handler.HandleRequest) _, err = engine.GetConn(domainID).PurgeAll() assert.Nil(t, err) diff --git a/server/logger/log.go b/server/logger/log.go index 5623d662..c5bc41c5 100644 --- a/server/logger/log.go +++ b/server/logger/log.go @@ -21,18 +21,17 @@ import ( "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/server/response" "github.com/fabiocicerchia/go-proxy-cache/utils" - "github.com/fabiocicerchia/go-proxy-cache/utils/slice" ) // Log - Logs against a requested URL. -func Log(req http.Request, message string) { +func Log(req http.Request, reqID string, message string) { logLine := fmt.Sprintf("%s %s %s - %s", req.Proto, req.Method, req.URL.String(), message) - log.Info(logLine) + log.WithFields(log.Fields{"ReqID": reqID}).Info(logLine) } // LogRequest - Logs the requested URL. -func LogRequest(req http.Request, lwr response.LoggedResponseWriter, cached bool, cached_label string) { +func LogRequest(req http.Request, lwr response.LoggedResponseWriter, reqID string, cached bool, cached_label string) { // NOTE: THIS IS FOR EVERY DOMAIN, NO DOMAIN OVERRIDE. // WHEN SHARING SAME PORT NO CUSTOM OVERRIDES ON CRITICAL SETTINGS. logLine := config.Config.Log.Format @@ -56,7 +55,7 @@ func LogRequest(req http.Request, lwr response.LoggedResponseWriter, cached bool `$request_method`, method, `$request`, req.URL.String(), `$status`, strconv.Itoa(lwr.StatusCode), - `$body_bytes_sent`, strconv.Itoa(slice.LenSliceBytes(lwr.Content)), + `$body_bytes_sent`, strconv.Itoa(lwr.Content.Len()), `$http_referer`, req.Referer(), `$http_user_agent`, req.UserAgent(), `$cached_status_label`, cached_label, @@ -65,21 +64,18 @@ func LogRequest(req http.Request, lwr response.LoggedResponseWriter, cached bool logLine = r.Replace(logLine) - log.Info(logLine) + log.WithFields(log.Fields{"ReqID": reqID}).Info(logLine) } // LogSetup - Logs the env variables required for a reverse proxy. func LogSetup(server config.Server) { forwardHost := utils.IfEmpty(server.Upstream.Host, "*") forwardProto := server.Upstream.Scheme - lbEndpointList := server.Upstream.Endpoints - log.Infof("Server will run on: %s and %s\n", server.Port.HTTP, server.Port.HTTPS) - - if len(lbEndpointList) == 0 { - log.Infof("Redirecting to url: %s://%s -> VOID\n", forwardProto, forwardHost) - return + lbEndpointList := fmt.Sprintf("%v", server.Upstream.Endpoints) + if len(server.Upstream.Endpoints) == 0 { + lbEndpointList = "VOID" } - log.Infof("Redirecting to url: %s://%s -> %v\n", forwardProto, forwardHost, lbEndpointList) + log.Infof("Server will run on :%s and :%s and redirects to url: %s://%s -> %s\n", server.Port.HTTP, server.Port.HTTPS, forwardProto, forwardHost, lbEndpointList) } diff --git a/server/logger/log_test.go b/server/logger/log_test.go index 4db4aca0..c1f4e61b 100644 --- a/server/logger/log_test.go +++ b/server/logger/log_test.go @@ -49,9 +49,9 @@ func TestLogMessage(t *testing.T) { }, } - logger.Log(reqMock, "message") + logger.Log(reqMock, "TestLogMessage", "message") - expectedOut := `time=" " level=info msg="HTTPS POST /path/to/file - message"` + "\n" + expectedOut := `time=" " level=info msg="HTTPS POST /path/to/file - message" ReqID=TestLogMessage` + "\n" assert.Equal(t, expectedOut, buf.String()) @@ -88,9 +88,9 @@ func TestLogRequest(t *testing.T) { }, } - logger.LogRequest(reqMock, lwrMock, true, "HIT") + logger.LogRequest(reqMock, lwrMock, "TestLogRequest", true, "HIT") - expectedOut := `time=" " level=info msg="example.org - 127.0.0.1 - - ? ? \"/path/to/file\" 404 7 \"https://www.google.com\" \"GoProxyCache\" true HIT"` + "\n" + expectedOut := `time=" " level=info msg="example.org - 127.0.0.1 - - ? ? \"/path/to/file\" 404 7 \"https://www.google.com\" \"GoProxyCache\" true HIT" ReqID=TestLogRequest` + "\n" assert.Equal(t, expectedOut, buf.String()) @@ -106,7 +106,7 @@ func TestLogSetup(t *testing.T) { log.SetOutput(os.Stderr) }() - config.Config = config.Configuration{ + cfg := config.Configuration{ Server: config.Server{ Port: config.Port{ HTTP: "80", @@ -118,16 +118,54 @@ func TestLogSetup(t *testing.T) { Endpoints: []string{"1.2.3.4", "8.8.8.8"}, }, }, + } + config.Config = config.Configuration{ + Log: config.Log{ + TimeFormat: "2006/01/02 15:04:05", + Format: `$host - $remote_addr - $remote_user $protocol $request_method "$request" $status $body_bytes_sent "$http_referer" "$http_user_agent" $cached_status $cached_status_label`, + }, + } + + logger.LogSetup(cfg.Server) + + expectedOut := `time=" " level=info msg="Server will run on :80 and :443 and redirects to url: https://www.google.com -> [1.2.3.4 8.8.8.8]\n"` + "\n" + assert.Equal(t, expectedOut, buf.String()) + + tearDownLog() +} + +func TestLogSetupWithoutEndpoints(t *testing.T) { + setUpLog() + + var buf bytes.Buffer + log.SetOutput(&buf) + defer func() { + log.SetOutput(os.Stderr) + }() + + cfg := config.Configuration{ + Server: config.Server{ + Port: config.Port{ + HTTP: "80", + HTTPS: "443", + }, + Upstream: config.Upstream{ + Host: "www.google.com", + Scheme: "https", + Endpoints: []string{}, + }, + }, + } + config.Config = config.Configuration{ Log: config.Log{ TimeFormat: "2006/01/02 15:04:05", Format: `$host - $remote_addr - $remote_user $protocol $request_method "$request" $status $body_bytes_sent "$http_referer" "$http_user_agent" $cached_status $cached_status_label`, }, } - logger.LogSetup(config.Config.Server) + logger.LogSetup(cfg.Server) - expectedOut := `time=" " level=info msg="Server will run on: 80 and 443\n"` + "\n" + - `time=" " level=info msg="Redirecting to url: https://www.google.com -> [1.2.3.4 8.8.8.8]\n"` + "\n" + expectedOut := `time=" " level=info msg="Server will run on :80 and :443 and redirects to url: https://www.google.com -> VOID\n"` + "\n" assert.Equal(t, expectedOut, buf.String()) tearDownLog() diff --git a/server/response/header.go b/server/response/header.go new file mode 100644 index 00000000..88baee30 --- /dev/null +++ b/server/response/header.go @@ -0,0 +1,25 @@ +package response + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +// CacheStatusHeader - HTTP Header for showing cache status. +const CacheStatusHeader = "X-Go-Proxy-Cache-Status" + +// CacheStatusHeader - HTTP Header for showing cache status. +const CacheBypassHeader = "X-Go-Proxy-Cache-Force-Fresh" + +// CacheStatusHeaderHit - Cache status HIT for HTTP Header X-Go-Proxy-Cache-Status. +const CacheStatusHeaderHit = "HIT" + +// CacheStatusHeaderMiss - Cache status MISS for HTTP Header X-Go-Proxy-Cache-Status. +const CacheStatusHeaderMiss = "MISS" + +// CacheStatusHeaderStale - Cache status STALE for HTTP Header X-Go-Proxy-Cache-Status. +const CacheStatusHeaderStale = "STALE" diff --git a/server/response/model.go b/server/response/model.go new file mode 100644 index 00000000..ff291208 --- /dev/null +++ b/server/response/model.go @@ -0,0 +1,32 @@ +package response + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + "github.com/fabiocicerchia/go-proxy-cache/utils/slice" +) + +type DataChunks [][]byte + +// Bytes - Returns flat slice of bytes. +func (dc DataChunks) Bytes() []byte { + bytes := []byte{} + + for _, c := range dc { + bytes = append(bytes, c...) + } + + return bytes +} + +// Len - Returns total length. +func (dc DataChunks) Len() int { + return slice.LenSliceBytes(([][]byte)(dc)) +} diff --git a/server/response/response.go b/server/response/response.go index fa50d90c..2c94aff4 100644 --- a/server/response/response.go +++ b/server/response/response.go @@ -11,36 +11,45 @@ package response import ( "bufio" + "compress/gzip" + "crypto/sha1" + "encoding/hex" "errors" + "fmt" + "hash" "net" "net/http" + + "github.com/go-http-utils/headers" ) var errHijackNotSupported = errors.New("hijack not supported") -// CacheStatusHeader - HTTP Header for showing cache status. -const CacheStatusHeader = "X-Go-Proxy-Cache-Status" - -// CacheStatusHeaderHit - Cache status HIT for HTTP Header X-Go-Proxy-Cache-Status. -const CacheStatusHeaderHit = "HIT" - -// CacheStatusHeaderMiss - Cache status MISS for HTTP Header X-Go-Proxy-Cache-Status. -const CacheStatusHeaderMiss = "MISS" - -// CacheStatusHeaderStale - Cache status STALE for HTTP Header X-Go-Proxy-Cache-Status. -const CacheStatusHeaderStale = "STALE" - // LoggedResponseWriter - Decorator for http.ResponseWriter. type LoggedResponseWriter struct { http.ResponseWriter http.Hijacker - StatusCode int - Content [][]byte + + ReqID string + statusCodeSent bool + StatusCode int + Content DataChunks + + // GZip + GZipResponse *gzip.Writer + + // ETag + hash hash.Hash + hashLen int } // NewLoggedResponseWriter - Creates new instance of ResponseWriter. -func NewLoggedResponseWriter(w http.ResponseWriter) *LoggedResponseWriter { - lwr := &LoggedResponseWriter{ResponseWriter: w} +func NewLoggedResponseWriter(w http.ResponseWriter, reqID string) *LoggedResponseWriter { + lwr := &LoggedResponseWriter{ + ReqID: reqID, + ResponseWriter: w, + hash: sha1.New(), + } lwr.Reset() return lwr @@ -59,21 +68,63 @@ func (lwr *LoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { // Reset - Reset the stored content of LoggedResponseWriter. func (lwr *LoggedResponseWriter) Reset() { lwr.StatusCode = 0 - lwr.Content = make([][]byte, 0) + lwr.Content = make(DataChunks, 0) } // WriteHeader - ResponseWriter's WriteHeader method decorator. func (lwr *LoggedResponseWriter) WriteHeader(statusCode int) { + lwr.statusCodeSent = true lwr.StatusCode = statusCode + + // no sending to ResponseWriter as it is buffered either for ETag or GZip support. +} + +// ForceWriteHeader - Send statusCode right away. +func (lwr *LoggedResponseWriter) ForceWriteHeader(statusCode int) { + lwr.WriteHeader(statusCode) + lwr.ResponseWriter.WriteHeader(statusCode) } +// SendNotImplemented - Send 501 Not Implemented. +func (lwr *LoggedResponseWriter) SendNotImplemented() { + lwr.ForceWriteHeader(http.StatusNotImplemented) +} + // Write - ResponseWriter's Write method decorator. func (lwr *LoggedResponseWriter) Write(p []byte) (int, error) { + if !lwr.statusCodeSent && lwr.StatusCode == 0 { + lwr.GetLogger().Warning("No status code has been set before sending data, fallback on 200 OK.") + // This is exactly what Go would also do if it hasn't been written yet. + lwr.StatusCode = http.StatusOK + } + lwr.Content = append(lwr.Content, []byte{}) chunk := len(lwr.Content) - 1 lwr.Content[chunk] = append(lwr.Content[chunk], p...) + // gzip + if lwr.GZipResponse != nil { + if lwr.ResponseWriter.Header().Get(headers.ContentType) == "" { + // If no content type, apply sniffing algorithm to un-gzipped body. + lwr.ResponseWriter.Header().Set(headers.ContentType, http.DetectContentType(p)) + } + + lwr.GZipResponse.Write(p) + } + + // etag + l, err := lwr.hash.Write(p) + lwr.hashLen += l + + // no sending to ResponseWriter as it is buffered either for ETag or GZip support. + return l, err +} + +// ForceWrite - Send content right away. +func (lwr *LoggedResponseWriter) ForceWrite(p []byte) (int, error) { + lwr.Write(p) + return lwr.ResponseWriter.Write(p) } @@ -81,15 +132,77 @@ func (lwr *LoggedResponseWriter) Write(p []byte) (int, error) { func (lwr *LoggedResponseWriter) CopyHeaders(src http.Header) { for k, vv := range src { for _, v := range vv { - lwr.Header().Add(k, v) + lwr.ResponseWriter.Header().Add(k, v) } } } -// WriteBody - Sends the body to the client. +// WriteBody - Sends the body to the client (forced sent). func (lwr *LoggedResponseWriter) WriteBody(page string) bool { pageByte := []byte(page) sent, err := lwr.ResponseWriter.Write(pageByte) return sent > 0 && err == nil } + +// SendResponse - Write the Response. +func (lwr LoggedResponseWriter) SendResponse() { + // TODO: Get extra behaviour from ServeCachedResponse + lwr.ResponseWriter.WriteHeader(lwr.StatusCode) + + // Generate GZip. + // lwr.GZipResponse.Close() will write some data even if no data has been written. + // StatusNotModified and StatusNoContent shouldn't have a body, so no triggering Close(). + if lwr.GZipResponse != nil && lwr.StatusCode != http.StatusNotModified && lwr.StatusCode != http.StatusNoContent { + // In this way it'll write in a nested LoggedResponseWriter so it can + // catch the binary data. + lwr.GZipResponse.Close() + } + + // Serve content. + lwr.ResponseWriter.Write(lwr.Content.Bytes()) +} + +// ETAG ------------------------------------------------------------------------ + +// GetETag - Returns the ETag value. +func (lwr LoggedResponseWriter) GetETag(weak bool) string { + etagWeakPrefix := "" + if weak { + etagWeakPrefix = "W/" + } + + return fmt.Sprintf(`"%s%d-%s"`, etagWeakPrefix, lwr.hashLen, hex.EncodeToString(lwr.hash.Sum(nil))) +} + +// SetETag - Set the ETag HTTP Header. +func (lwr *LoggedResponseWriter) SetETag(weak bool) { + lwr.ResponseWriter.Header().Set(headers.ETag, lwr.GetETag(weak)) +} + +// MustServeOriginalResponse - Check whether an ETag could be added. +func (lwr LoggedResponseWriter) MustServeOriginalResponse(req *http.Request) bool { + lwr.GetLogger().Debugf("MustServerOriginalResponse - no hash has been computed (maybe no Write has been invoked): %v", lwr.hash == nil) + lwr.GetLogger().Debugf("MustServerOriginalResponse - there's already an ETag from upstream: %v (%s)", lwr.ResponseWriter.Header().Get(headers.ETag) != "", lwr.ResponseWriter.Header().Get(headers.ETag)) + lwr.GetLogger().Debugf("MustServerOriginalResponse - response is not successful (2xx): %v (%d)", (lwr.StatusCode < http.StatusOK || lwr.StatusCode >= http.StatusMultipleChoices), lwr.StatusCode) + lwr.GetLogger().Debugf("MustServerOriginalResponse - response is without content (204): %v", lwr.StatusCode == http.StatusNoContent) + lwr.GetLogger().Debugf("MustServerOriginalResponse - there is no buffered content (maybe no Write has been invoked): %v", len(lwr.Content) == 0) + + return lwr.hash == nil || // no hash has been computed (maybe no Write has been invoked) + lwr.ResponseWriter.Header().Get(headers.ETag) != "" || // there's already an ETag from upstream + (lwr.StatusCode < http.StatusOK || lwr.StatusCode >= http.StatusMultipleChoices) || // response is not successful (2xx) + lwr.StatusCode == http.StatusNoContent || // response is without content (204) + len(lwr.Content) == 0 // there is no buffered content (maybe no Write has been invoked) +} + +// SendNotModifiedResponse - Write the 304 Response. +func (lwr LoggedResponseWriter) SendNotModifiedResponse() { + lwr.ResponseWriter.WriteHeader(http.StatusNotModified) + lwr.ResponseWriter.Write(nil) +} + +// GZIP ------------------------------------------------------------------------ +func (lwr *LoggedResponseWriter) InitGZipBuffer() { + lwrGzip := &LoggedResponseWriter{ResponseWriter: lwr.ResponseWriter} + lwr.GZipResponse = gzip.NewWriter(lwrGzip) +} diff --git a/server/response/response_test.go b/server/response/response_test.go index 6b69f385..759890ee 100644 --- a/server/response/response_test.go +++ b/server/response/response_test.go @@ -12,23 +12,28 @@ package response_test // Repo: https://github.com/fabiocicerchia/go-proxy-cache import ( + "compress/gzip" "net/http" + "net/http/httptest" + "regexp" "testing" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/fabiocicerchia/go-proxy-cache/server/response" - "github.com/stretchr/testify/assert" ) var MockStatusCode int -var MockContent [][]byte +var MockContent response.DataChunks type ResponseWriterMock struct { http.ResponseWriter } -func (rwm ResponseWriterMock) WriteHeader(statusCode int) { MockStatusCode = statusCode } +func (rwm ResponseWriterMock) WriteHeader(statusCode int) { + MockStatusCode = statusCode +} func (rwm ResponseWriterMock) Write(p []byte) (int, error) { MockContent = append(MockContent, []byte{}) chunk := len(MockContent) - 1 @@ -49,9 +54,9 @@ func initLogs() { func TestNewWriter(t *testing.T) { initLogs() - var rwMock ResponseWriterMock + rwMock := ResponseWriterMock{} - lwr := response.NewLoggedResponseWriter(rwMock) + lwr := response.NewLoggedResponseWriter(rwMock, "TestNewWriter") assert.Equal(t, 0, lwr.StatusCode) assert.Len(t, lwr.Content, 0) @@ -62,15 +67,34 @@ func TestNewWriter(t *testing.T) { func TestCatchStatusCode(t *testing.T) { initLogs() - var rwMock ResponseWriterMock + rwMock := ResponseWriterMock{} - lwr := response.NewLoggedResponseWriter(rwMock) + lwr := response.NewLoggedResponseWriter(rwMock, "TestCatchStatusCode") lwr.WriteHeader(http.StatusCreated) // checks lwr assert.Equal(t, http.StatusCreated, lwr.StatusCode) assert.Len(t, lwr.Content, 0) + // verify calls on rwMock + assert.Equal(t, -1, MockStatusCode) + assert.Len(t, MockContent, 0) + + tearDownResponse() +} + +func TestCatchStatusCodeForced(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestCatchStatusCodeForced") + lwr.ForceWriteHeader(http.StatusCreated) + + // checks lwr + assert.Equal(t, http.StatusCreated, lwr.StatusCode) + assert.Len(t, lwr.Content, 0) + // verify calls on rwMock assert.Equal(t, http.StatusCreated, MockStatusCode) assert.Len(t, MockContent, 0) @@ -81,23 +105,54 @@ func TestCatchStatusCode(t *testing.T) { func TestCatchContent(t *testing.T) { initLogs() - var rwMock ResponseWriterMock + rwMock := ResponseWriterMock{} - lwr := response.NewLoggedResponseWriter(rwMock) + lwr := response.NewLoggedResponseWriter(rwMock, "TestCatchContent") content := []byte("test content") _, err := lwr.Write(content) assert.Nil(t, err) - expectedContent := [][]byte{content} + expectedContent := response.DataChunks{content} // checks lwr - assert.Equal(t, 0, lwr.StatusCode) + // even if don't set it explicitly, it fallback on 200 + assert.Equal(t, http.StatusOK, lwr.StatusCode) + assert.Equal(t, expectedContent, lwr.Content) + + // verify calls on rwMock + assert.Equal(t, -1, MockStatusCode) + // Empty because buffered. + assert.Equal(t, 0, MockContent.Len()) + assert.Equal(t, []byte{}, MockContent.Bytes()) + assert.Equal(t, response.DataChunks{}, MockContent) + + tearDownResponse() +} + +func TestCatchContentForced(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestCatchContentForced") + + content := []byte("test content") + _, err := lwr.ForceWrite(content) + assert.Nil(t, err) + + expectedContent := response.DataChunks{content} + + // checks lwr + // even if don't set it explicitly, it fallback on 200 + assert.Equal(t, http.StatusOK, lwr.StatusCode) assert.Equal(t, expectedContent, lwr.Content) // verify calls on rwMock assert.Equal(t, -1, MockStatusCode) assert.Equal(t, expectedContent, MockContent) + assert.Equal(t, 12, MockContent.Len()) + assert.Equal(t, content, MockContent.Bytes()) tearDownResponse() } @@ -105,34 +160,131 @@ func TestCatchContent(t *testing.T) { func TestCatchContentThreeChunks(t *testing.T) { initLogs() - var rwMock ResponseWriterMock + rwMock := ResponseWriterMock{} - lwr := response.NewLoggedResponseWriter(rwMock) + lwr := response.NewLoggedResponseWriter(rwMock, "TestCatchContentThreeChunks") content := []byte("test content") content2 := []byte("test content2") content3 := []byte("test content3") - _, err := lwr.Write(content) + _, err := lwr.ForceWrite(content) assert.Nil(t, err) - _, err = lwr.Write(content2) + _, err = lwr.ForceWrite(content2) assert.Nil(t, err) - _, err = lwr.Write(content3) + _, err = lwr.ForceWrite(content3) assert.Nil(t, err) - expectedContent := [][]byte{content, content2, content3} + expectedContent := response.DataChunks{content, content2, content3} // checks lwr - assert.Equal(t, 0, lwr.StatusCode) + // even if don't set it explicitly, it fallback on 200 + assert.Equal(t, http.StatusOK, lwr.StatusCode) assert.Equal(t, expectedContent, lwr.Content) // verify calls on rwMock assert.Equal(t, -1, MockStatusCode) assert.Equal(t, expectedContent, MockContent) + assert.Equal(t, 38, MockContent.Len()) tearDownResponse() } +func TestSendNotImplemented(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestSendNotImplemented") + lwr.SendNotImplemented() + + // checks lwr + assert.Equal(t, http.StatusNotImplemented, lwr.StatusCode) + + // verify calls on rwMock + assert.Equal(t, http.StatusNotImplemented, MockStatusCode) + + tearDownResponse() +} + +func TestSendNotModifiedResponse(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestSendNotModifiedResponse") + lwr.SendNotModifiedResponse() + + // checks lwr + // it sends only to the internal writer + assert.Equal(t, 0, lwr.StatusCode) + assert.Equal(t, response.DataChunks{}, lwr.Content) + + // verify calls on rwMock + assert.Equal(t, http.StatusNotModified, MockStatusCode) + assert.Equal(t, response.DataChunks{[]byte{}}, MockContent) + + tearDownResponse() +} + +func TestGetETagWeak(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestGetETagWeak") + + etag := lwr.GetETag(true) + + assert.Regexp(t, regexp.MustCompile(`^\"W/[0-9]+-[0-9a-f]{40}\"$`), etag) +} + +func TestGetETagNotWeak(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestGetETagNotWeak") + + etag := lwr.GetETag(false) + + assert.Regexp(t, regexp.MustCompile(`^\"[0-9]+-[0-9a-f]{40}\"$`), etag) +} + +func TestSetETagWeak(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{ResponseWriter: httptest.NewRecorder()} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestSetETagWeak") + lwr.SetETag(true) + + assert.Regexp(t, regexp.MustCompile(`^\"W/[0-9]+-[0-9a-f]{40}\"$`), lwr.ResponseWriter.Header().Get("ETag")) +} + +func TestSetETagNotWeak(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{ResponseWriter: httptest.NewRecorder()} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestSetETagNotWeak") + lwr.SetETag(false) + + assert.Regexp(t, regexp.MustCompile(`^\"[0-9]+-[0-9a-f]{40}\"$`), lwr.ResponseWriter.Header().Get("ETag")) +} + +func TestInitGZipBuffer(t *testing.T) { + initLogs() + + rwMock := ResponseWriterMock{ResponseWriter: httptest.NewRecorder()} + + lwr := response.NewLoggedResponseWriter(rwMock, "TestInitGZipBuffer") + lwr.InitGZipBuffer() + + assert.NotNil(t, lwr.GZipResponse) + assert.IsType(t, &gzip.Writer{}, lwr.GZipResponse) +} + func tearDownResponse() { MockStatusCode = -1 - MockContent = make([][]byte, 0) + MockContent = make(response.DataChunks, 0) } diff --git a/server/response/utils.go b/server/response/utils.go new file mode 100644 index 00000000..53c55f8a --- /dev/null +++ b/server/response/utils.go @@ -0,0 +1,21 @@ +package response + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + log "github.com/sirupsen/logrus" +) + +// LoggedResponseWriter - Decorator for http.ResponseWriter. +func (lwr LoggedResponseWriter) GetLogger() *log.Entry { + return log.WithFields(log.Fields{ + "ReqID": lwr.ReqID, + }) +} diff --git a/server/server.go b/server/server.go index db1d2fc6..467e115d 100644 --- a/server/server.go +++ b/server/server.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/NYTimes/gziphandler" log "github.com/sirupsen/logrus" "github.com/fabiocicerchia/go-proxy-cache/cache/engine" @@ -35,10 +34,15 @@ const enableTimeoutHandler = true // DefaultTimeoutShutdown - Default Timeout for shutting down a context. const DefaultTimeoutShutdown time.Duration = 5 * time.Second +type Server struct { + Domain string + HttpSrv http.Server +} + // Servers - Contains the HTTP/HTTPS servers. type Servers struct { - HTTP map[string]*http.Server - HTTPS map[string]*http.Server + HTTP map[string]Server + HTTPS map[string]Server } // Run - Starts the GoProxyCache servers' listeners. @@ -50,8 +54,8 @@ func Run(configFile string) { config.Print() servers := &Servers{ - HTTP: make(map[string]*http.Server), - HTTPS: make(map[string]*http.Server), + HTTP: make(map[string]Server), + HTTPS: make(map[string]Server), } for _, domain := range config.GetDomains() { @@ -87,31 +91,19 @@ func Run(configFile string) { } // InitServer - Generates the http.Server configuration. -func InitServer(domain string) *http.Server { +func InitServer(domain string, domainConfig config.Configuration) http.Server { mux := http.NewServeMux() // handlers - // NOTE: THIS IS FOR EVERY DOMAIN, NO DOMAIN OVERRIDE. - // WHEN SHARING SAME PORT NO CUSTOM OVERRIDES ON CRITICAL SETTINGS. - if config.Config.Server.Healthcheck { - mux.HandleFunc("/healthcheck", handler.HandleHealthcheck(config.Config)) + if domainConfig.Server.Healthcheck { + mux.HandleFunc("/healthcheck", handler.HandleHealthcheck(domainConfig)) } - mux.HandleFunc("/", handler.HandleRequest(config.Config)) + mux.HandleFunc("/", handler.HandleRequest) // basic var muxMiddleware http.Handler = mux - // etag middleware - muxMiddleware = ConditionalETag(muxMiddleware) - - // gzip middleware - // NOTE: THIS IS FOR EVERY DOMAIN, NO DOMAIN OVERRIDE. - // WHEN SHARING SAME PORT NO CUSTOM OVERRIDES ON CRITICAL SETTINGS. - if config.Config.Server.GZip { - muxMiddleware = gziphandler.GzipHandler(muxMiddleware) - } - // timeout middleware // NOTE: THIS IS FOR EVERY DOMAIN, NO DOMAIN OVERRIDE. // WHEN SHARING SAME PORT NO CUSTOM OVERRIDES ON CRITICAL SETTINGS. @@ -120,7 +112,7 @@ func InitServer(domain string) *http.Server { muxMiddleware = http.TimeoutHandler(muxMiddleware, timeout.Handler, "Timed Out\n") } - server := &http.Server{ + server := http.Server{ ReadTimeout: timeout.Read * time.Second, WriteTimeout: timeout.Write * time.Second, IdleTimeout: timeout.Idle * time.Second, @@ -135,28 +127,26 @@ func InitServer(domain string) *http.Server { // NOTE: There will be only ONE server listening on a port. // This means the last processed will override all the previous shared // settings. THIS COULD LEAD TO CONFLICTS WHEN SHARING THE SAME PORT. -func (s *Servers) AttachPlain(port string, server *http.Server) { - s.HTTP[port] = server - s.HTTP[port].Addr = ":" + port +func (s *Servers) AttachPlain(domain string, port string, server http.Server) { + s.HTTP[port] = Server{Domain: domain, HttpSrv: server} } // AttachSecure - Adds a new HTTPS server in the listener container. // NOTE: There will be only ONE server listening on a port. // This means the last processed will override all the previous shared // settings. THIS COULD LEAD TO CONFLICTS WHEN SHARING THE SAME PORT. -func (s *Servers) AttachSecure(port string, server *http.Server) { - s.HTTPS[port] = server - s.HTTPS[port].Addr = ":" + port +func (s *Servers) AttachSecure(domain string, port string, server http.Server) { + s.HTTPS[port] = Server{Domain: domain, HttpSrv: server} } // InitServers - Returns a http.Server configuration for HTTP and HTTPS. -func (s *Servers) InitServers(domain string, domainConfig config.Server) { - srv := InitServer(domain) - s.AttachPlain(domainConfig.Port.HTTP, srv) +func (s *Servers) InitServers(domain string, domainConfig config.Configuration) { + srvHTTP := InitServer(domain, domainConfig) + s.AttachPlain(domain, domainConfig.Server.Port.HTTP, srvHTTP) - srvHTTPS := InitServer(domain) + srvHTTPS := InitServer(domain, domainConfig) - err := srvtls.ServerOverrides(domain, srvHTTPS, domainConfig) + err := srvtls.ServerOverrides(domain, &srvHTTPS, domainConfig.Server) if err != nil { log.Errorf("Skipping '%s' TLS server configuration: %s", domain, err) log.Errorf("No HTTPS server will be listening on '%s'", domain) @@ -164,53 +154,61 @@ func (s *Servers) InitServers(domain string, domainConfig config.Server) { return } - s.AttachSecure(domainConfig.Port.HTTPS, srvHTTPS) + s.AttachSecure(domain, domainConfig.Server.Port.HTTPS, srvHTTPS) } // StartDomainServer - Configures and start listening for a particular domain. func (s *Servers) StartDomainServer(domain string, scheme string) { - domainConfig := config.DomainConf(domain, scheme) - if domainConfig == nil { + domainConfig, found := config.DomainConf(domain, scheme) + if !found { log.Errorf("Missing configuration for %s.", domain) return } domainID := domainConfig.Server.Upstream.GetDomainID() - // redis connect - circuitbreaker.InitCircuitBreaker(domainID, domainConfig.CircuitBreaker) - engine.InitConn(domainID, domainConfig.Cache) - // Log setup values logger.LogSetup(domainConfig.Server) + // redis connect + circuitbreaker.InitCircuitBreaker(domainID, domainConfig.CircuitBreaker) + engine.InitConn(domainID, domainConfig.Cache, log.StandardLogger()) + // config server http & https - s.InitServers(domain, domainConfig.Server) + s.InitServers(domain, domainConfig) // lb balancer.InitRoundRobin(domainID, domainConfig.Server.Upstream.Endpoints) } func (s Servers) startListeners() { - for _, srvHTTP := range s.HTTP { - go func(srv *http.Server) { log.Fatal(srv.ListenAndServe()) }(srvHTTP) + for port, srvHTTP := range s.HTTP { + srvHTTP.HttpSrv.Addr = ":" + port + + go func(srv http.Server) { + log.Fatal(srv.ListenAndServe()) + }(srvHTTP.HttpSrv) } - for _, srvHTTPS := range s.HTTPS { - go func(srv *http.Server) { log.Fatal(srv.ListenAndServeTLS("", "")) }(srvHTTPS) + for port, srvHTTPS := range s.HTTPS { + srvHTTPS.HttpSrv.Addr = ":" + port + + go func(srv http.Server) { + log.Fatal(srv.ListenAndServeTLS("", "")) + }(srvHTTPS.HttpSrv) } } func (s Servers) shutdownServers(ctx context.Context) { for k, v := range s.HTTP { - err := v.Shutdown(ctx) + err := v.HttpSrv.Shutdown(ctx) if err != nil { log.Fatalf("Cannot shutdown server %s: %s", k, err) } } for k, v := range s.HTTPS { - err := v.Shutdown(ctx) + err := v.HttpSrv.Shutdown(ctx) if err != nil { log.Fatalf("Cannot shutdown server %s: %s", k, err) } diff --git a/server/storage/storage.go b/server/storage/storage.go index 15838a41..42166529 100644 --- a/server/storage/storage.go +++ b/server/storage/storage.go @@ -22,19 +22,20 @@ import ( // RequestCallDTO - DTO object containing request and response. type RequestCallDTO struct { + ReqID string Response response.LoggedResponseWriter Request http.Request CacheObject cache.Object } // RetrieveCachedContent - Retrives the cached response. -func RetrieveCachedContent(rc RequestCallDTO) (cache.URIObj, error) { +func RetrieveCachedContent(rc RequestCallDTO, logger *log.Entry) (cache.URIObj, error) { err := rc.CacheObject.RetrieveFullPage() if err != nil { if err == cache.ErrEmptyValue { - log.Infof("Cannot retrieve page %s: %s\n", rc.CacheObject.CurrentURIObject.URL.String(), err) + logger.Infof("Cannot retrieve page %s: %s\n", rc.CacheObject.CurrentURIObject.URL.String(), err) } else { - log.Warnf("Cannot retrieve page %s: %s\n", rc.CacheObject.CurrentURIObject.URL.String(), err) + logger.Warnf("Cannot retrieve page %s: %s\n", rc.CacheObject.CurrentURIObject.URL.String(), err) } return cache.URIObj{}, err @@ -50,8 +51,9 @@ func RetrieveCachedContent(rc RequestCallDTO) (cache.URIObj, error) { // StoreGeneratedPage - Stores a response in the cache. func StoreGeneratedPage(rc RequestCallDTO, domainConfigCache config.Cache) (bool, error) { - ttl := ttl.GetTTL(rc.Response.Header(), domainConfigCache.TTL) - done, err := rc.CacheObject.StoreFullPage(ttl) + // Use the static rc.CacheObject.CurrentURIObject.ResponseHeaders to avoid data race + currentTTL := ttl.GetTTL(rc.CacheObject.CurrentURIObject.ResponseHeaders, domainConfigCache.TTL) + done, err := rc.CacheObject.StoreFullPage(currentTTL) return done, err } diff --git a/server/tls/tls.go b/server/tls/tls.go index 78a0daa7..d61d67f0 100644 --- a/server/tls/tls.go +++ b/server/tls/tls.go @@ -15,7 +15,6 @@ import ( "net/http" "github.com/pkg/errors" - log "github.com/sirupsen/logrus" "github.com/fabiocicerchia/go-proxy-cache/config" @@ -87,14 +86,6 @@ func Config(domain string, domainConfigTLS config.TLS) (*crypto_tls.Config, erro tlsConfig.Certificates = append(tlsConfig.Certificates, *c) } - // TODO: THIS COULD LEAD TO CONFLICTS WHEN SHARING THE SAME PORT. - if domainConfigTLS.Override != nil { - tlsConfig.CurvePreferences = domainConfigTLS.Override.CurvePreferences - tlsConfig.MinVersion = domainConfigTLS.Override.MinVersion - tlsConfig.MaxVersion = domainConfigTLS.Override.MaxVersion - tlsConfig.CipherSuites = domainConfigTLS.Override.CipherSuites - } - return tlsConfig, nil } diff --git a/server/transport/http.go b/server/transport/http.go index b7d45f33..73cb85e2 100644 --- a/server/transport/http.go +++ b/server/transport/http.go @@ -81,7 +81,7 @@ func ServeCachedResponse(ctx context.Context, lwr *response.LoggedResponseWriter PushProxiedResources(lwr, &uriobj) - handleBody(lwr, uriobj.Content) + handleBody(lwr.ResponseWriter, uriobj.Content) handleTrailer(announcedTrailers, lwr, res) } @@ -113,8 +113,8 @@ func handleHeaders(lwr *response.LoggedResponseWriter, res http.Response) int { return announcedTrailers } -func handleBody(lwr *response.LoggedResponseWriter, content [][]byte) { - copyResponse(lwr, content) +func handleBody(res http.ResponseWriter, content [][]byte) { + copyResponse(res, content) } func handleTrailer(announcedTrailers int, lwr *response.LoggedResponseWriter, res http.Response) { diff --git a/server/transport/http2_test.go b/server/transport/http2_test.go index bb031fcb..84c00b0e 100644 --- a/server/transport/http2_test.go +++ b/server/transport/http2_test.go @@ -14,9 +14,9 @@ package transport_test import ( "testing" - "github.com/fabiocicerchia/go-proxy-cache/server/transport" - "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/server/transport" ) func TestParseSimple(t *testing.T) { diff --git a/test/etag_test.go b/test/etag_test.go index 2a9ffa7d..8722f454 100644 --- a/test/etag_test.go +++ b/test/etag_test.go @@ -25,6 +25,10 @@ func TestETagValidResponse(t *testing.T) { client := &http.Client{} req, err := http.NewRequest("GET", "http://testing.local:50080/", nil) + // Need to fetch fresh content to verify the ETag. + req.Header = http.Header{ + "X-Go-Proxy-Cache-Force-Fresh": []string{"1"}, + } assert.Nil(t, err) req.Host = "www.w3.org" res, err := client.Do(req) @@ -40,6 +44,7 @@ func TestETagValidResponse(t *testing.T) { res.Body.Close() assert.Equal(t, "MISS", res.Header.Get("X-Go-Proxy-Cache-Status")) + // this is the real ETag from w3.org assert.Regexp(t, regexp.MustCompile(`^\"[0-9a-f]{4}-[0-9a-f]{13};[0-9a-f]{2}-[0-9a-f]{13}-gzip\"$`), res.Header.Get("ETag")) assert.Equal(t, "HTTP/1.1", res.Proto) @@ -50,6 +55,8 @@ func TestETagValidResponse(t *testing.T) { assert.Contains(t, string(body), "World Wide Web Consortium (W3C)`) assert.Contains(t, string(body), "\n\n") + + tearDownETag() } func TestETagIfModifiedSinceWhenChanged(t *testing.T) { @@ -209,3 +216,10 @@ func TestETagIfMatchAsMatch(t *testing.T) { func TestETagIfMatchAsNotMatch(t *testing.T) { t.Skip("Need to be implemented.") } + +func tearDownETag() { + req, _ := http.NewRequest("PURGE", "http://testing.local:50080/", nil) + req.Host = "www.w3.org" + client := &http.Client{} + _, _ = client.Do(req) +} diff --git a/test/full-setup/Dockerfile.nginx b/test/full-setup/Dockerfile.nginx index b3528917..0bee3992 100644 --- a/test/full-setup/Dockerfile.nginx +++ b/test/full-setup/Dockerfile.nginx @@ -1,5 +1,5 @@ # TODO: code smell, should be configured in github actions -FROM nginx:1.19.5-alpine +FROM nginx:1.21.1-alpine COPY nginx/vhost.conf /etc/nginx/conf.d/vhost.conf COPY certs /certs diff --git a/test/full-setup/Dockerfile.node b/test/full-setup/Dockerfile.node index ac27c914..732054f6 100644 --- a/test/full-setup/Dockerfile.node +++ b/test/full-setup/Dockerfile.node @@ -1,5 +1,5 @@ # TODO: code smell, should be configured in github actions -FROM node:15.4.0-alpine3.10 +FROM node:16.8.0-alpine3.14 WORKDIR /home/node/app diff --git a/test/full-setup/config.no-docker.yml b/test/full-setup/config.no-docker.yml deleted file mode 100644 index e1ae3ba8..00000000 --- a/test/full-setup/config.no-docker.yml +++ /dev/null @@ -1,82 +0,0 @@ -# TESTING PURPOSES ONLY -# TODO: code smell, should be only one file -################################################################################ -server: - port: - http: "50080" - https: "50443" - tls: - cert_file: certs/default/server.pem - key_file: certs/default/server.key - timeout: - read: 5s - read_header: 2s - write: 5s - idle: 20s - handler: -1 - -cache: - host: redis - -domains: - example_com: - server: - upstream: - host: example.com - - example_org: - server: - upstream: - host: example.org - - www_w3_org: - server: - upstream: - host: www.w3.org - endpoints: - - www.w3.org - tls: - cert_file: certs/www.w3.org/server.pem - key_file: certs/www.w3.org/server.key - - www_testing_local: - server: - upstream: - host: www.testing.local - scheme: http - endpoints: - - 127.0.0.1:40080 - - testing_local: - server: - upstream: - host: testing.local - scheme: http - endpoints: - - 127.0.0.1:40080 - - testing_local_https: - server: - upstream: - host: testing.local - scheme: https - insecure_bridge: true - endpoints: - - 127.0.0.1:40443 - - testing_local_ws: - server: - upstream: - host: testing.local - scheme: ws - endpoints: - - 127.0.0.1:40081 - - testing_local_wss: - server: - upstream: - host: testing.local - scheme: wss - insecure_bridge: true - endpoints: - - 127.0.0.1:40082 diff --git a/test/http2_test.go b/test/http2_test.go index c365b4ff..64d2c1cb 100644 --- a/test/http2_test.go +++ b/test/http2_test.go @@ -58,6 +58,8 @@ func TestHTTP2ClientCall(t *testing.T) { assert.Contains(t, string(body), "World Wide Web Consortium (W3C)`) assert.Contains(t, string(body), "\n\n") + + tearDownHttp2() } func TestHTTP2ClientCallToMissingDomain(t *testing.T) { @@ -85,3 +87,18 @@ func TestHTTP2ClientCallToMissingDomain(t *testing.T) { assert.Equal(t, http.StatusNotImplemented, res.StatusCode) } + +func tearDownHttp2() { + req, _ := http.NewRequest("PURGE", "https://testing.local:50443/", nil) + req.Host = "www.w3.org" + client := &http.Client{ + Transport: &http2.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + DisableCompression: true, + AllowHTTP: false, + }, + } + _, _ = client.Do(req) +} diff --git a/test/http_test.go b/test/http_test.go index 32388535..7fa11826 100644 --- a/test/http_test.go +++ b/test/http_test.go @@ -48,6 +48,8 @@ func TestHTTPClientCall(t *testing.T) { assert.Contains(t, string(body), "About W3C`) assert.Contains(t, string(body), "\n") + + tearDownHttp() } func TestHTTPClientCallToMissingDomain(t *testing.T) { @@ -67,3 +69,10 @@ func TestHTTPClientCallToMissingDomain(t *testing.T) { assert.Equal(t, http.StatusNotImplemented, res.StatusCode) } + +func tearDownHttp() { + req, _ := http.NewRequest("PURGE", "http://testing.local:50080/Consortium/", nil) + req.Host = "www.w3.org" + client := &http.Client{} + _, _ = client.Do(req) +} diff --git a/test/https_test.go b/test/https_test.go index d59400bb..e46f73c0 100644 --- a/test/https_test.go +++ b/test/https_test.go @@ -55,6 +55,8 @@ func TestHTTPSClientCall(t *testing.T) { assert.Contains(t, string(body), "World Wide Web Consortium (W3C)`) assert.Contains(t, string(body), "\n\n") + + tearDownHttps() } func TestHTTPSClientCallToMissingDomain(t *testing.T) { @@ -80,3 +82,16 @@ func TestHTTPSClientCallToMissingDomain(t *testing.T) { assert.Equal(t, http.StatusNotImplemented, res.StatusCode) } + +func tearDownHttps() { + req, _ := http.NewRequest("PURGE", "https://testing.local:50443/", nil) + req.Host = "www.w3.org" + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + _, _ = client.Do(req) +} diff --git a/utils/base64/base64_test.go b/utils/base64/base64_test.go index 696a27d7..f1a172b1 100644 --- a/utils/base64/base64_test.go +++ b/utils/base64/base64_test.go @@ -14,8 +14,9 @@ package base64_test import ( "testing" - "github.com/fabiocicerchia/go-proxy-cache/utils/base64" "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/utils/base64" ) func TestEncodeDecode(t *testing.T) { diff --git a/utils/circuit-breaker/circuit-breaker.go b/utils/circuit-breaker/circuit-breaker.go index 7688e83c..4e1c4cff 100644 --- a/utils/circuit-breaker/circuit-breaker.go +++ b/utils/circuit-breaker/circuit-breaker.go @@ -50,7 +50,7 @@ func cbReadyToTrip(config CircuitBreaker) func(counts gobreaker.Counts) bool { } func cbOnStateChange(name string, from gobreaker.State, to gobreaker.State) { - log.Warnf("Circuit Breaker - Changed from %s to %s", from.String(), to.String()) + log.Warnf("Circuit Breaker '%s' - Changed from %s to %s", name, from.String(), to.String()) } // CB - Returns instance of gobreaker.CircuitBreaker. diff --git a/utils/convert/convert_test.go b/utils/convert/convert_test.go index 61a711e4..6b362eb6 100644 --- a/utils/convert/convert_test.go +++ b/utils/convert/convert_test.go @@ -15,8 +15,9 @@ import ( "testing" "time" - "github.com/fabiocicerchia/go-proxy-cache/utils/convert" "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/utils/convert" ) // --- ToDuration diff --git a/utils/msgpack/msgpack_test.go b/utils/msgpack/msgpack_test.go index bbe0a652..3f3e60ff 100644 --- a/utils/msgpack/msgpack_test.go +++ b/utils/msgpack/msgpack_test.go @@ -14,8 +14,9 @@ package msgpack_test import ( "testing" - "github.com/fabiocicerchia/go-proxy-cache/utils/msgpack" "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/utils/msgpack" ) func TestEncodeDecode(t *testing.T) { diff --git a/utils/queue/queue.go b/utils/queue/queue.go index 93edc201..c81cd9d4 100644 --- a/utils/queue/queue.go +++ b/utils/queue/queue.go @@ -1,20 +1,33 @@ package queue +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + import ( "time" "github.com/sdeoras/dispatcher" ) +// MaxConcurrency - How many concurrent functions can be executed. // TODO: Make it customizable? const MaxConcurrency = 10 +// Dispatcher - Global queue dispatcher. var Dispatcher dispatcher.Dispatcher +// Init - Init a new dispatcher. func Init() { Dispatcher = dispatcher.New(MaxConcurrency) } +// WaitForCompletion - Waits all functions are completed. func WaitForCompletion() { for Dispatcher.IsRunning() { time.Sleep(time.Second) diff --git a/utils/string/string.go b/utils/scheme/scheme.go similarity index 98% rename from utils/string/string.go rename to utils/scheme/scheme.go index e7876e75..eba9f0a1 100644 --- a/utils/string/string.go +++ b/utils/scheme/scheme.go @@ -1,4 +1,4 @@ -package string +package scheme // __ // .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. diff --git a/utils/scheme/scheme_test.go b/utils/scheme/scheme_test.go new file mode 100644 index 00000000..577ca585 --- /dev/null +++ b/utils/scheme/scheme_test.go @@ -0,0 +1,42 @@ +package scheme_test + +// __ +// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. +// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| +// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| +// |_____| |__| |_____| +// +// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License +// Repo: https://github.com/fabiocicerchia/go-proxy-cache + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/utils/scheme" +) + +func TestNormalizeHTTP(t *testing.T) { + assert.Equal(t, "http", scheme.NormalizeScheme("http")) + assert.Equal(t, "http", scheme.NormalizeScheme("HTTP")) + assert.Equal(t, "http", scheme.NormalizeScheme("HttP")) + assert.Equal(t, "http", scheme.NormalizeScheme("HTtp")) + assert.Equal(t, "http", scheme.NormalizeScheme("HtTp")) +} + +func TestNormalizeHTTPS(t *testing.T) { + assert.Equal(t, "https", scheme.NormalizeScheme("https")) + assert.Equal(t, "https", scheme.NormalizeScheme("HTTPS")) + assert.Equal(t, "https", scheme.NormalizeScheme("HttPs")) + assert.Equal(t, "https", scheme.NormalizeScheme("HTtps")) + assert.Equal(t, "https", scheme.NormalizeScheme("HtTpS")) +} + +func TestNormalizeNonExisting(t *testing.T) { + assert.Equal(t, "", scheme.NormalizeScheme("")) + assert.Equal(t, "", scheme.NormalizeScheme("1")) + assert.Equal(t, "", scheme.NormalizeScheme("-")) + assert.Equal(t, "", scheme.NormalizeScheme("qwerty")) + assert.Equal(t, "", scheme.NormalizeScheme("wss")) +} diff --git a/utils/slice/slice_test.go b/utils/slice/slice_test.go index 3f53a2da..cf8bcaf9 100644 --- a/utils/slice/slice_test.go +++ b/utils/slice/slice_test.go @@ -14,8 +14,9 @@ package slice_test import ( "testing" - "github.com/fabiocicerchia/go-proxy-cache/utils/slice" "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/utils/slice" ) // --- ContainsInt diff --git a/utils/string/string_test.go b/utils/string/string_test.go deleted file mode 100644 index 62dd2e54..00000000 --- a/utils/string/string_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package string_test - -// __ -// .-----.-----.______.-----.----.-----.--.--.--.--.______.----.---.-.----| |--.-----. -// | _ | _ |______| _ | _| _ |_ _| | |______| __| _ | __| | -__| -// |___ |_____| | __|__| |_____|__.__|___ | |____|___._|____|__|__|_____| -// |_____| |__| |_____| -// -// Copyright (c) 2020 Fabio Cicerchia. https://fabiocicerchia.it. MIT License -// Repo: https://github.com/fabiocicerchia/go-proxy-cache - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - utilsString "github.com/fabiocicerchia/go-proxy-cache/utils/string" -) - -func TestNormalizeHTTP(t *testing.T) { - assert.Equal(t, "http", utilsString.NormalizeScheme("http")) - assert.Equal(t, "http", utilsString.NormalizeScheme("HTTP")) - assert.Equal(t, "http", utilsString.NormalizeScheme("HttP")) - assert.Equal(t, "http", utilsString.NormalizeScheme("HTtp")) - assert.Equal(t, "http", utilsString.NormalizeScheme("HtTp")) -} - -func TestNormalizeHTTPS(t *testing.T) { - assert.Equal(t, "https", utilsString.NormalizeScheme("https")) - assert.Equal(t, "https", utilsString.NormalizeScheme("HTTPS")) - assert.Equal(t, "https", utilsString.NormalizeScheme("HttPs")) - assert.Equal(t, "https", utilsString.NormalizeScheme("HTtps")) - assert.Equal(t, "https", utilsString.NormalizeScheme("HtTpS")) -} - -func TestNormalizeNonExisting(t *testing.T) { - assert.Equal(t, "", utilsString.NormalizeScheme("")) - assert.Equal(t, "", utilsString.NormalizeScheme("1")) - assert.Equal(t, "", utilsString.NormalizeScheme("-")) - assert.Equal(t, "", utilsString.NormalizeScheme("qwerty")) - assert.Equal(t, "", utilsString.NormalizeScheme("wss")) -} diff --git a/utils/ttl/ttl_test.go b/utils/ttl/ttl_test.go index af0eb076..441f0764 100644 --- a/utils/ttl/ttl_test.go +++ b/utils/ttl/ttl_test.go @@ -16,8 +16,9 @@ import ( "testing" "time" - "github.com/fabiocicerchia/go-proxy-cache/utils/ttl" "github.com/stretchr/testify/assert" + + "github.com/fabiocicerchia/go-proxy-cache/utils/ttl" ) func TestGetTTLFromCacheControlWithMaxage(t *testing.T) { diff --git a/utils/utils_test.go b/utils/utils_test.go index 2ad6ca6c..fccb460d 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -16,9 +16,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/fabiocicerchia/go-proxy-cache/config" "github.com/fabiocicerchia/go-proxy-cache/utils" - "github.com/stretchr/testify/assert" ) // --- GetEnv