From d451fff49d511e5b4f1398b401193ab7291fb707 Mon Sep 17 00:00:00 2001 From: Evgeny Abramovich Date: Sun, 17 Sep 2023 14:55:41 -0300 Subject: [PATCH] Automatically reload the server after configuration changes (#22) * Move app logic to separate structure * Moved mapping normalisation in config loading * Optimised memory allocation at starting time * Auto restert draft * Refactor server stoppting * Updated roadmap * Fixed config watching * Added graceful shutdown * Cleanup code * Move server package functionality to uncors app * Fixed linting * Added tests for restarting * WIP: tests * WIP * Removed unused deps * Enables skipped tests * Fixed tests for GracefulShutdown * Added helpers.PanicInterceptor for config reloading --- .github/workflows/go.yml | 2 +- .gitignore | 1 - ROADMAP.md | 9 +- go.mod | 5 +- go.sum | 4 +- internal/config/config.go | 26 ++- internal/config/config_test.go | 69 ++++-- internal/config/helpers.go | 2 +- internal/helpers/graceful_shutdown.go | 43 ++++ .../graceful_shutdown_internal_test.go | 142 ++++++++++++ internal/helpers/panic_test.go | 3 +- internal/server/atomic_bool.go | 17 -- internal/server/atomic_bool_test.go | 53 ----- internal/server/server.go | 126 ---------- internal/server/server_test.go | 101 -------- internal/uncors/app.go | 156 +++++++++++++ internal/uncors/app_test.go | 217 ++++++++++++++++++ internal/uncors/handler.go | 88 +++++++ internal/uncors/listen.go | 63 +++++ internal/{ui => uncors}/loggers.go | 2 +- internal/{ui => uncors}/logo.go | 2 +- internal/{ui => uncors}/logo_test.go | 7 +- internal/{ui => uncors}/messages.go | 2 +- internal/uncors/shutdown.go | 34 +++ internal/version/new_version_check.go | 4 +- main.go | 151 +++--------- testing/mocks/closer_mock.go | 2 +- testing/mocks/http_client_mock.go | 2 +- testing/mocks/logger_mock.go | 27 +-- testing/mocks/replacer_factory_mock.go | 2 +- testing/mocks/writer_mock.go | 2 +- testing/testutils/certs.go | 20 +- 32 files changed, 891 insertions(+), 493 deletions(-) create mode 100644 internal/helpers/graceful_shutdown.go create mode 100644 internal/helpers/graceful_shutdown_internal_test.go delete mode 100644 internal/server/atomic_bool.go delete mode 100644 internal/server/atomic_bool_test.go delete mode 100644 internal/server/server.go delete mode 100644 internal/server/server_test.go create mode 100644 internal/uncors/app.go create mode 100644 internal/uncors/app_test.go create mode 100644 internal/uncors/handler.go create mode 100644 internal/uncors/listen.go rename internal/{ui => uncors}/loggers.go (97%) rename internal/{ui => uncors}/logo.go (98%) rename internal/{ui => uncors}/logo_test.go (97%) rename internal/{ui => uncors}/messages.go (95%) create mode 100644 internal/uncors/shutdown.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 461d4019..a09cfbd5 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -31,7 +31,7 @@ jobs: run: go build -tags release -v . - name: Test - run: go test -tags release -v -coverprofile=coverage.out ./... + run: go test -tags release -timeout 1m -v -coverprofile=coverage.out ./... - name: SonarCloud Scan uses: SonarSource/sonarcloud-github-action@master diff --git a/.gitignore b/.gitignore index bfb61df4..279a7ff5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,6 @@ __debug_bin server.key server.crt dist/ -uncors .idea node_modules .uncors.yaml diff --git a/ROADMAP.md b/ROADMAP.md index bc6028d6..362ea96c 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -52,14 +52,15 @@ ## [0.1.0 Release](https://github.com/evg4b/uncors/releases/tag/v0.1.0) -- [X] Static file serving [PR](https://github.com/evg4b/uncors/pull/15) +- [X] Static file serving - [PR](https://github.com/evg4b/uncors/pull/15) - [X] Own error page for uncors internal errors -- [X] Separated mock for each url mapping [PR](https://github.com/evg4b/uncors/pull/16) +- [X] Separated mock for each url mapping - [PR](https://github.com/evg4b/uncors/pull/16) ## Next Release -- [X] Response caching [PR](https://github.com/evg4b/uncors/pull/17) -- [X] JSON Schema for config file [PR](https://github.com/evg4b/uncors/pull/19) +- [X] Response caching - [PR](https://github.com/evg4b/uncors/pull/17) +- [X] JSON Schema for config file - [PR](https://github.com/evg4b/uncors/pull/19) +- [ ] Automatically reload the server after configuration changes - [PR](https://github.com/evg4b/uncors/pull/22) ## Future features diff --git a/go.mod b/go.mod index 3735bf10..e702e243 100644 --- a/go.mod +++ b/go.mod @@ -8,13 +8,12 @@ require ( github.com/PuerkitoBio/purell v1.2.0 github.com/bmatcuk/doublestar/v4 v4.6.0 github.com/go-http-utils/headers v0.0.0-20181008091004-fed159eddc2a - github.com/go-playground/assert/v2 v2.2.0 github.com/go-playground/validator/v10 v10.15.4 github.com/gojuno/minimock/v3 v3.1.3 github.com/gorilla/mux v1.8.0 github.com/hashicorp/go-version v1.6.0 github.com/mitchellh/mapstructure v1.5.0 - github.com/pseidemann/finish v1.2.0 + github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/pterm/pterm v0.12.69 github.com/samber/lo v1.38.1 github.com/spf13/afero v1.9.5 @@ -30,7 +29,7 @@ require ( atomicgo.dev/schedule v0.1.0 // indirect github.com/containerd/console v1.0.3 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/fsnotify/fsnotify v1.6.0 github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect diff --git a/go.sum b/go.sum index 917adaf8..6c1bb27a 100644 --- a/go.sum +++ b/go.sum @@ -203,13 +203,13 @@ github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaR github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/pseidemann/finish v1.2.0 h1:XrEc9FCnBPulyM9NvAptAtcOCZZYHwV0MRCcnCfQlnw= -github.com/pseidemann/finish v1.2.0/go.mod h1:Wl17vXLhlT9a/K7jryhExgJPfbs4+dUpRaauEWt7oQ4= github.com/pterm/pterm v0.12.27/go.mod h1:PhQ89w4i95rhgE+xedAoqous6K9X+r6aSOI2eFF7DZI= github.com/pterm/pterm v0.12.29/go.mod h1:WI3qxgvoQFFGKGjGnJR849gU0TsEOvKn5Q8LlY1U7lg= github.com/pterm/pterm v0.12.30/go.mod h1:MOqLIyMOgmTDz9yorcYbcw+HsgoZo3BQfg2wtl3HEFE= diff --git a/internal/config/config.go b/internal/config/config.go index b999e1bf..f2184479 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,8 @@ package config import ( "fmt" + "github.com/evg4b/uncors/internal/helpers" + "github.com/mitchellh/mapstructure" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -13,6 +15,8 @@ const ( defaultHTTPSPort = 443 ) +var flags *pflag.FlagSet + type UncorsConfig struct { HTTPPort int `mapstructure:"http-port" validate:"required"` Mappings Mappings `mapstructure:"mappings" validate:"required"` @@ -29,7 +33,8 @@ func (c *UncorsConfig) IsHTTPSEnabled() bool { } func LoadConfiguration(viperInstance *viper.Viper, args []string) *UncorsConfig { - flags := defineFlags() + defineFlags() + helpers.AssertIsDefined(flags) if err := flags.Parse(args); err != nil { panic(fmt.Errorf("filed parsing flags: %w", err)) } @@ -61,14 +66,25 @@ func LoadConfiguration(viperInstance *viper.Viper, args []string) *UncorsConfig } if err := readURLMapping(viperInstance, configuration); err != nil { - panic(fmt.Errorf("recognize url mapping: %w", err)) + panic(err) + } + + configuration.Mappings = NormaliseMappings( + configuration.Mappings, + configuration.HTTPPort, + configuration.HTTPSPort, + configuration.IsHTTPSEnabled(), + ) + + if err := Validate(configuration); err != nil { + panic(err) } return configuration } -func defineFlags() *pflag.FlagSet { - flags := pflag.NewFlagSet("uncors", pflag.ContinueOnError) +func defineFlags() { + flags = pflag.NewFlagSet("uncors", pflag.ContinueOnError) flags.Usage = pflag.Usage flags.StringSliceP("to", "t", []string{}, "Target host with protocol for to the resource to be proxy") flags.StringSliceP("from", "f", []string{}, "Local host with protocol for to the resource from which proxying will take place") //nolint: lll @@ -79,6 +95,4 @@ func defineFlags() *pflag.FlagSet { flags.String("proxy", "", "HTTP/HTTPS proxy to provide requests to real server (used system by default)") flags.Bool("debug", false, "Show debug output") flags.StringP("config", "c", "", "Path to the configuration file") - - return flags } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 5d7093d2..eb21d09e 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1,6 +1,8 @@ +// nolint: nosprintfhostport package config_test import ( + "fmt" "net/http" "testing" "time" @@ -76,15 +78,22 @@ mappings: ) func TestLoadConfiguration(t *testing.T) { - viperInstance := viper.New() - viperInstance.SetFs(testutils.FsFromMap(t, map[string]string{ + fs := testutils.FsFromMap(t, map[string]string{ corruptedConfigPath: corruptedConfig, fullConfigPath: fullConfig, incorrectConfigPath: incorrectConfig, minimalConfigPath: minimalConfig, - })) + }) t.Run("correctly parse config", func(t *testing.T) { + HTTPf := func(host string, port int) string { + return fmt.Sprintf("http://%s:%d", host, port) + } + + HTTPSf := func(host string, port int) string { + return fmt.Sprintf("https://%s:%d", host, port) + } + tests := []struct { name string args []string @@ -111,7 +120,7 @@ func TestLoadConfiguration(t *testing.T) { HTTPPort: 8080, HTTPSPort: 443, Mappings: config.Mappings{ - {From: testconstants.HTTPLocalhost, To: testconstants.HTTPSGithub}, + {From: testconstants.HTTPLocalhostWithPort(8080), To: testconstants.HTTPSGithub}, }, CacheConfig: config.CacheConfig{ ExpirationTime: config.DefaultExpirationTime, @@ -126,9 +135,9 @@ func TestLoadConfiguration(t *testing.T) { expected: &config.UncorsConfig{ HTTPPort: 8080, Mappings: config.Mappings{ - {From: testconstants.HTTPLocalhost, To: testconstants.HTTPSGithub}, + {From: testconstants.HTTPLocalhostWithPort(8080), To: testconstants.HTTPSGithub}, { - From: testconstants.HTTPLocalhost2, + From: testconstants.HTTPLocalhost2WithPort(8080), To: testconstants.HTTPSStackoverflow, Mocks: config.Mocks{ { @@ -178,9 +187,9 @@ func TestLoadConfiguration(t *testing.T) { expected: &config.UncorsConfig{ HTTPPort: 8080, Mappings: config.Mappings{ - {From: testconstants.HTTPLocalhost, To: testconstants.HTTPSGithub}, + {From: testconstants.HTTPLocalhostWithPort(8080), To: testconstants.HTTPSGithub}, { - From: testconstants.HTTPLocalhost2, + From: testconstants.HTTPLocalhost2WithPort(8080), To: testconstants.HTTPSStackoverflow, Mocks: config.Mocks{ { @@ -203,9 +212,12 @@ func TestLoadConfiguration(t *testing.T) { }, }, }, - {From: testconstants.SourceHost1, To: testconstants.TargetHost1}, - {From: testconstants.SourceHost2, To: testconstants.TargetHost2}, - {From: testconstants.SourceHost3, To: testconstants.TargetHost3}, + {From: HTTPf(testconstants.SourceHost1, 8080), To: testconstants.TargetHost1}, + {From: HTTPSf(testconstants.SourceHost1, 8081), To: testconstants.TargetHost1}, + {From: HTTPf(testconstants.SourceHost2, 8080), To: testconstants.TargetHost2}, + {From: HTTPSf(testconstants.SourceHost2, 8081), To: testconstants.TargetHost2}, + {From: HTTPf(testconstants.SourceHost3, 8080), To: testconstants.TargetHost3}, + {From: HTTPSf(testconstants.SourceHost3, 8081), To: testconstants.TargetHost3}, }, Proxy: "localhost:8080", Debug: true, @@ -223,8 +235,13 @@ func TestLoadConfiguration(t *testing.T) { }, }, } + for _, testCase := range tests { t.Run(testCase.name, func(t *testing.T) { + viper.Reset() + viperInstance := viper.New() + viperInstance.SetFs(fs) + uncorsConfig := config.LoadConfiguration(viperInstance, testCase.args) assert.Equal(t, testCase.expected, uncorsConfig) @@ -253,7 +270,7 @@ func TestLoadConfiguration(t *testing.T) { params.To, testconstants.TargetHost1, }, expected: []string{ - "recognize url mapping: `from` values are not set for every `to`", + "`from` values are not set for every `to`", }, }, { @@ -263,7 +280,7 @@ func TestLoadConfiguration(t *testing.T) { params.From, testconstants.SourceHost2, }, expected: []string{ - "recognize url mapping: `to` values are not set for every `from`", + "`to` values are not set for every `from`", }, }, { @@ -273,19 +290,18 @@ func TestLoadConfiguration(t *testing.T) { params.To, testconstants.TargetHost2, }, expected: []string{ - "recognize url mapping: `from` values are not set for every `to`", + "`from` values are not set for every `to`", + }, + }, + { + name: "config file doesn't exist", + args: []string{ + params.Config, "/not-exist-config.yaml", + }, + expected: []string{ + "filed to read config file '/not-exist-config.yaml': open /not-exist-config.yaml: file does not exist", }, }, - //{ - // name: "config file doesn't exist", - // args: []string{ - // params.Config, "/not-exist-config.yaml", - // }, - // expected: []string{ - // "filed to read config file '/not-exist-config.yaml': open ", - // "open /not-exist-config.yaml: file does not exist", - // }, - // }, { name: "config file is corrupted", args: []string{ @@ -318,8 +334,13 @@ func TestLoadConfiguration(t *testing.T) { }, } for _, testCase := range tests { + testCase := testCase t.Run(testCase.name, func(t *testing.T) { for _, expected := range testCase.expected { + viper.Reset() + viperInstance := viper.New() + viperInstance.SetFs(fs) + assert.PanicsWithError(t, expected, func() { config.LoadConfiguration(viperInstance, testCase.args) }) diff --git a/internal/config/helpers.go b/internal/config/helpers.go index 338405ff..d1013045 100644 --- a/internal/config/helpers.go +++ b/internal/config/helpers.go @@ -75,7 +75,7 @@ const ( ) func NormaliseMappings(mappings Mappings, httpPort int, httpsPort int, useHTTPS bool) Mappings { - var processedMappings Mappings + processedMappings := Mappings{} for _, mapping := range mappings { sourceURL, err := urlx.Parse(mapping.From) if err != nil { diff --git a/internal/helpers/graceful_shutdown.go b/internal/helpers/graceful_shutdown.go new file mode 100644 index 00000000..32394d5c --- /dev/null +++ b/internal/helpers/graceful_shutdown.go @@ -0,0 +1,43 @@ +package helpers + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +var ( + notifyFn = signal.Notify + sigintFix = func() { + // fix prints after "^C" + println("") // nolint:forbidigo + } +) + +func GracefulShutdown(ctx context.Context, shutdownFunc func(ctx context.Context) error) error { + if done := waiteSignal(ctx); done { + return nil + } + + return shutdownFunc(ctx) +} + +func waiteSignal(ctx context.Context) bool { + stop := make(chan os.Signal, 1) + + notifyFn(stop, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) + + defer close(stop) + + select { + case sig := <-stop: + if sig == syscall.SIGINT { + sigintFix() + } + case <-ctx.Done(): + return true + } + + return false +} diff --git a/internal/helpers/graceful_shutdown_internal_test.go b/internal/helpers/graceful_shutdown_internal_test.go new file mode 100644 index 00000000..97e6dc6f --- /dev/null +++ b/internal/helpers/graceful_shutdown_internal_test.go @@ -0,0 +1,142 @@ +package helpers + +import ( + "context" + "os" + "os/signal" + "sync" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +type Env struct { + wg *sync.WaitGroup + afterAll []func() +} + +func (e *Env) Go(action func()) { + e.wg.Add(1) + go func() { + defer e.wg.Done() + action() + }() +} + +func (e *Env) CheckAfterAll(action func()) { + e.afterAll = append(e.afterAll, action) +} + +func WithGoroutines(test func(t *testing.T, env Env)) func(t *testing.T) { + return func(t *testing.T) { + env := Env{wg: &sync.WaitGroup{}} + test(t, env) + env.wg.Wait() + for _, f := range env.afterAll { + f() + } + } +} + +func TestGracefulShutdown(t *testing.T) { + t.Run("shutdown when context is done", WithGoroutines(func(t *testing.T, env Env) { + ctx, cancel := context.WithCancel(context.Background()) + + called := false + env.Go(func() { + err := GracefulShutdown(ctx, func(ctx context.Context) error { + called = true + + return nil + }) + assert.NoError(t, err) + }) + + env.CheckAfterAll(func() { + assert.True(t, called) + }) + + cancel() + })) + + t.Run("shutdown after system signal", func(t *testing.T) { + tests := []struct { + name string + signal os.Signal + }{ + { + name: "SIGINT", + signal: syscall.SIGINT, + }, + { + name: "SIGTERM", + signal: syscall.SIGTERM, + }, + { + name: "SIGHUP", + signal: syscall.SIGHUP, + }, + } + + for _, testCase := range tests { + t.Run(testCase.name, WithGoroutines(func(t *testing.T, env Env) { + var systemSig chan<- os.Signal + notifyFn = func(c chan<- os.Signal, sig ...os.Signal) { + systemSig = c + } + + t.Cleanup(func() { + notifyFn = signal.Notify + }) + + called := false + env.Go(func() { + err := GracefulShutdown(context.Background(), func(ctx context.Context) error { + called = true + + return nil + }) + assert.NoError(t, err) + }) + + <-time.After(50 * time.Millisecond) + systemSig <- testCase.signal + + env.CheckAfterAll(func() { + assert.True(t, called) + }) + })) + } + }) + + t.Run("apply additional ui fix for SIGINT signal", WithGoroutines(func(t *testing.T, env Env) { + var systemSig chan<- os.Signal + notifyFn = func(c chan<- os.Signal, sig ...os.Signal) { + systemSig = c + } + called := false + sigintFix = func() { + called = true + } + + t.Cleanup(func() { + notifyFn = signal.Notify + }) + + env.Go(func() { + err := GracefulShutdown(context.Background(), func(ctx context.Context) error { + return nil + }) + assert.NoError(t, err) + }) + + <-time.After(50 * time.Millisecond) + systemSig <- syscall.SIGINT + + env.CheckAfterAll(func() { + assert.True(t, called) + }) + })) +} diff --git a/internal/helpers/panic_test.go b/internal/helpers/panic_test.go index 2fd527a3..13ccc798 100644 --- a/internal/helpers/panic_test.go +++ b/internal/helpers/panic_test.go @@ -4,9 +4,10 @@ package helpers_test import ( "errors" - "github.com/evg4b/uncors/internal/helpers" "testing" + "github.com/evg4b/uncors/internal/helpers" + "github.com/stretchr/testify/assert" ) diff --git a/internal/server/atomic_bool.go b/internal/server/atomic_bool.go deleted file mode 100644 index 95543ef0..00000000 --- a/internal/server/atomic_bool.go +++ /dev/null @@ -1,17 +0,0 @@ -package server - -import "sync/atomic" - -type AtomicBool int32 - -func (b *AtomicBool) IsSet() bool { - return atomic.LoadInt32((*int32)(b)) != 0 -} - -func (b *AtomicBool) SetTrue() { - atomic.StoreInt32((*int32)(b), 1) -} - -func (b *AtomicBool) SetFalse() { - atomic.StoreInt32((*int32)(b), 0) -} diff --git a/internal/server/atomic_bool_test.go b/internal/server/atomic_bool_test.go deleted file mode 100644 index 5b9be3a9..00000000 --- a/internal/server/atomic_bool_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package server_test - -import ( - "testing" - - "github.com/evg4b/uncors/internal/server" - "github.com/go-playground/assert/v2" -) - -func TestAtomicBool(t *testing.T) { - tests := []struct { - name string - getBool func() *server.AtomicBool - expected bool - }{ - { - name: "Should be false by default", - getBool: func() *server.AtomicBool { - var value server.AtomicBool - - return &value - }, - expected: false, - }, - { - name: "Should be false after SetFalse", - getBool: func() *server.AtomicBool { - var value server.AtomicBool - value.SetFalse() - - return &value - }, - expected: false, - }, - { - name: "Should be true after SetTrue", - getBool: func() *server.AtomicBool { - var value server.AtomicBool - value.SetTrue() - - return &value - }, - expected: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - value := tt.getBool() - - assert.IsEqual(value.IsSet(), tt.expected) - }) - } -} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 73f9b037..00000000 --- a/internal/server/server.go +++ /dev/null @@ -1,126 +0,0 @@ -//nolint:wrapcheck -package server - -import ( - "errors" - "net" - "net/http" - "time" - - "github.com/evg4b/uncors/internal/contracts" - "github.com/evg4b/uncors/internal/helpers" - "github.com/evg4b/uncors/internal/log" - "golang.org/x/net/context" -) - -type UncorsServer struct { - *http.Server - inShutdown AtomicBool -} - -const ( - readHeaderTimeout = 30 * time.Second - shutdownTimeout = 15 * time.Second -) - -func NewUncorsServer(ctx context.Context, handler contracts.Handler) *UncorsServer { - globalCtx, globalCtxCancel := context.WithCancel(ctx) - server := &http.Server{ - BaseContext: func(listener net.Listener) context.Context { - return globalCtx - }, - ReadHeaderTimeout: readHeaderTimeout, - Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - helpers.NormaliseRequest(request) - handler.ServeHTTP(contracts.WrapResponseWriter(writer), request) - }), - ErrorLog: log.StandardErrorLogAdapter(), - } - server.RegisterOnShutdown(globalCtxCancel) - - return &UncorsServer{ - Server: server, - } -} - -func (srv *UncorsServer) ListenAndServe(addr string) error { - if srv.shuttingDown() { - return http.ErrServerClosed - } - - if addr == "" { - addr = ":http" - } - - listener, err := net.Listen("tcp", addr) - if err != nil { - return err - } - - defer listener.Close() - - srv.Addr = listener.Addr().String() - err = srv.Serve(listener) - if err != nil { - srv.internalShutdown() - } - - return err -} - -func (srv *UncorsServer) ListenAndServeTLS(addr string, certFile, keyFile string) error { - if srv.shuttingDown() { - return http.ErrServerClosed - } - - if addr == "" { - addr = ":https" - } - - listener, err := net.Listen("tcp", addr) - if err != nil { - return err - } - - defer listener.Close() - - srv.Addr = listener.Addr().String() - err = srv.ServeTLS(listener, certFile, keyFile) - if err != nil { - srv.internalShutdown() - } - - return err -} - -func (srv *UncorsServer) Shutdown(ctx context.Context) error { - srv.inShutdown.SetTrue() - - return srv.Server.Shutdown(ctx) //nolint:wrapcheck -} - -func (srv *UncorsServer) Close() error { - srv.inShutdown.SetTrue() - - return srv.Server.Close() -} - -func (srv *UncorsServer) shuttingDown() bool { - return srv.inShutdown.IsSet() -} - -func (srv *UncorsServer) internalShutdown() { - ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) - defer cancel() - log.Debug("uncors: shutting down ...") - err := srv.Shutdown(ctx) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - log.Errorf("finish: shutdown timeout for UNCORS server") - } else { - log.Errorf("finish: error while shutting down UNCORS server: %s", err) - } - } else { - log.Debug("finish: UNCORS server closed") - } -} diff --git a/internal/server/server_test.go b/internal/server/server_test.go deleted file mode 100644 index a01a9baf..00000000 --- a/internal/server/server_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package server_test - -import ( - "context" - "io" - "net/http" - "net/url" - "testing" - "time" - - "github.com/evg4b/uncors/internal/contracts" - "github.com/evg4b/uncors/internal/helpers" - "github.com/evg4b/uncors/internal/server" - "github.com/evg4b/uncors/testing/testutils" - "github.com/stretchr/testify/assert" -) - -func TestNewUncorsServer(t *testing.T) { - ctx := context.Background() - expectedResponse := "UNCORS OK!" - - var handler contracts.HandlerFunc = func(w contracts.ResponseWriter, _ *contracts.Request) { - w.WriteHeader(http.StatusOK) - helpers.Fprint(w, expectedResponse) - } - - t.Run("handle request", func(t *testing.T) { - t.Run("HTTP", func(t *testing.T) { - uncorsServer := server.NewUncorsServer(ctx, handler) - defer func() { - err := uncorsServer.Close() - testutils.CheckNoServerError(t, err) - }() - - go func() { - err := uncorsServer.ListenAndServe("127.0.0.1:0") - testutils.CheckNoServerError(t, err) - }() - - time.Sleep(300 * time.Millisecond) - uri, err := url.Parse("http://" + uncorsServer.Addr) - testutils.CheckNoError(t, err) - - res, err := http.DefaultClient.Do(&http.Request{URL: uri, Method: http.MethodGet}) - testutils.CheckNoError(t, err) - defer helpers.CloseSafe(res.Body) - - data, err := io.ReadAll(res.Body) - testutils.CheckNoError(t, err) - - assert.Equal(t, expectedResponse, string(data)) - }) - - t.Run("HTTPS", testutils.WithTmpCerts(func(t *testing.T, certs *testutils.Certs) { - uncorsServer := server.NewUncorsServer(ctx, handler) - defer func() { - testutils.CheckNoServerError(t, uncorsServer.Close()) - }() - - go func() { - err := uncorsServer.ListenAndServeTLS("127.0.0.1:0", certs.CertPath, certs.KeyPath) - testutils.CheckNoServerError(t, err) - }() - - httpClient := http.Client{ - Transport: &http.Transport{ - TLSClientConfig: certs.ClientTLSConf, - }, - } - - time.Sleep(300 * time.Millisecond) - uri, err := url.Parse("https://" + uncorsServer.Addr) - testutils.CheckNoError(t, err) - - response, err := httpClient.Do(&http.Request{URL: uri, Method: http.MethodGet}) - testutils.CheckNoError(t, err) - defer helpers.CloseSafe(response.Body) - - actualResponse, err := io.ReadAll(response.Body) - testutils.CheckNoError(t, err) - - assert.Equal(t, expectedResponse, string(actualResponse)) - })) - }) - - t.Run("run already stopped server", func(t *testing.T) { - uncorsServer := server.NewUncorsServer(ctx, handler) - testutils.CheckNoServerError(t, uncorsServer.Close()) - - t.Run("HTTP", func(t *testing.T) { - err := uncorsServer.ListenAndServe("127.0.0.1:0") - - assert.ErrorIs(t, err, http.ErrServerClosed) - }) - t.Run("HTTPS", testutils.WithTmpCerts(func(t *testing.T, certs *testutils.Certs) { - err := uncorsServer.ListenAndServeTLS("127.0.0.1:0", certs.CertPath, certs.KeyPath) - - assert.ErrorIs(t, err, http.ErrServerClosed) - })) - }) -} diff --git a/internal/uncors/app.go b/internal/uncors/app.go new file mode 100644 index 00000000..4eeff2b6 --- /dev/null +++ b/internal/uncors/app.go @@ -0,0 +1,156 @@ +// nolint: wrapcheck +package uncors + +import ( + "errors" + "fmt" + "net" + "net/http" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" + "github.com/evg4b/uncors/internal/helpers" + "github.com/evg4b/uncors/internal/log" + "github.com/spf13/afero" + "golang.org/x/net/context" +) + +type App struct { + fs afero.Fs + version string + waitGroup *sync.WaitGroup + httpMutex *sync.Mutex + httpsMutex *sync.Mutex + server *http.Server + shuttingDown *atomic.Bool + httpListener net.Listener + httpsListener net.Listener + cache appCache +} + +const ( + baseAddress = "127.0.0.1" + readHeaderTimeout = 30 * time.Second + shutdownTimeout = 15 * time.Second +) + +func CreateApp(fs afero.Fs, version string) *App { + return &App{ + fs: fs, + version: version, + waitGroup: &sync.WaitGroup{}, + httpMutex: &sync.Mutex{}, + httpsMutex: &sync.Mutex{}, + shuttingDown: &atomic.Bool{}, + } +} + +func (app *App) Start(ctx context.Context, uncorsConfig *config.UncorsConfig) { + log.Print(Logo(app.version)) + log.Print("\n") + log.Warning(DisclaimerMessage) + log.Print("\n") + log.Info(uncorsConfig.Mappings.String()) + log.Print("\n") + + app.initServer(ctx, uncorsConfig) +} + +func (app *App) initServer(ctx context.Context, uncorsConfig *config.UncorsConfig) { + app.shuttingDown.Store(false) + app.server = app.createServer(ctx, uncorsConfig) + + app.waitGroup.Add(1) + go func() { + defer app.waitGroup.Done() + defer app.httpMutex.Unlock() + + app.httpMutex.Lock() + log.Debugf("Starting http server on port %d", uncorsConfig.HTTPPort) + addr := net.JoinHostPort(baseAddress, strconv.Itoa(uncorsConfig.HTTPPort)) + err := app.listenAndServe(addr) + handleHTTPServerError("HTTP", err) + }() + + if uncorsConfig.IsHTTPSEnabled() { + log.Debug("Found cert file and key file. Https server will be started") + addr := net.JoinHostPort(baseAddress, strconv.Itoa(uncorsConfig.HTTPSPort)) + app.waitGroup.Add(1) + go func() { + defer app.waitGroup.Done() + defer app.httpsMutex.Unlock() + + app.httpsMutex.Lock() + log.Debugf("Starting https server on port %d", uncorsConfig.HTTPSPort) + err := app.listenAndServeTLS(addr, uncorsConfig.CertFile, uncorsConfig.KeyFile) + handleHTTPServerError("HTTPS", err) + }() + } +} + +func (app *App) createServer(ctx context.Context, uncorsConfig *config.UncorsConfig) *http.Server { + globalHandler := app.buildHandler(uncorsConfig) + globalCtx, globalCtxCancel := context.WithCancel(ctx) + server := &http.Server{ + BaseContext: func(listener net.Listener) context.Context { + return globalCtx + }, + ReadHeaderTimeout: readHeaderTimeout, + Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + helpers.NormaliseRequest(request) + globalHandler.ServeHTTP(contracts.WrapResponseWriter(writer), request) + }), + ErrorLog: log.StandardErrorLogAdapter(), + } + server.RegisterOnShutdown(globalCtxCancel) + + return server +} + +func (app *App) Restart(ctx context.Context, uncorsConfig *config.UncorsConfig) { + defer app.waitGroup.Done() + app.waitGroup.Add(1) + log.Print("\n") + log.Info("Restarting server....") + log.Print("\n") + err := app.internalShutdown(ctx) + if err != nil { + panic(err) // TODO: refactor this error handling + } + + log.Info(uncorsConfig.Mappings.String()) + log.Print("\n") + app.initServer(ctx, uncorsConfig) +} + +func (app *App) Close() error { + return app.server.Close() +} + +func (app *App) Wait() { + app.waitGroup.Wait() +} + +func (app *App) Shutdown(ctx context.Context) error { + return app.internalShutdown(ctx) +} + +func (app *App) HTTPAddr() net.Addr { + return app.httpListener.Addr() // TODO: Add nil handing +} + +func (app *App) HTTPSAddr() net.Addr { + return app.httpsListener.Addr() // TODO: Add nil handing +} + +func handleHTTPServerError(serverName string, err error) { + if err == nil || errors.Is(err, http.ErrServerClosed) { + log.Debugf("%s server was stopped without errors", serverName) + } else { + panic(fmt.Errorf("%s server was stopped with error %w", serverName, err)) + } +} diff --git a/internal/uncors/app_test.go b/internal/uncors/app_test.go new file mode 100644 index 00000000..dcdf498e --- /dev/null +++ b/internal/uncors/app_test.go @@ -0,0 +1,217 @@ +package uncors_test + +import ( + "bytes" + "context" + "io" + "net/http" + "net/url" + "testing" + "time" + + "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/helpers" + "github.com/evg4b/uncors/internal/uncors" + "github.com/evg4b/uncors/testing/testutils" + "github.com/phayes/freeport" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" +) + +const delay = 10 * time.Millisecond + +func TestUncorsApp(t *testing.T) { + ctx := context.Background() + fs := afero.NewOsFs() + expectedResponse := "UNCORS OK!" + + t.Run("handle request", testutils.LogTest(func(t *testing.T, output *bytes.Buffer) { + t.Run("HTTP", func(t *testing.T) { + uncorsApp, uri := createApp(ctx, t, fs, false, &config.UncorsConfig{ + HTTPPort: freeport.GetPort(), + Mappings: config.Mappings{ + config.Mapping{ + From: "http://127.0.0.1", + To: "https://github.com", + Mocks: mocks(expectedResponse), + }, + }, + }) + defer func() { + err := uncorsApp.Close() + testutils.CheckNoServerError(t, err) + }() + + response := makeRequest(t, http.DefaultClient, uri) + + assert.Equal(t, expectedResponse, response) + }) + + t.Run("HTTPS", testutils.WithTmpCerts(fs, func(t *testing.T, certs *testutils.Certs) { + uncorsApp, uri := createApp(ctx, t, fs, true, &config.UncorsConfig{ + HTTPSPort: freeport.GetPort(), + CertFile: certs.CertPath, + KeyFile: certs.KeyPath, + Mappings: config.Mappings{ + config.Mapping{ + From: "https://127.0.0.1", + To: "https://github.com", + Mocks: mocks(expectedResponse), + }, + }, + }) + defer func() { + err := uncorsApp.Close() + testutils.CheckNoServerError(t, err) + }() + + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: certs.ClientTLSConf, + }, + } + + response := makeRequest(t, httpClient, uri) + + assert.Equal(t, expectedResponse, response) + })) + })) + + t.Run("restart server", testutils.LogTest(func(t *testing.T, output *bytes.Buffer) { + const otherExpectedRepose = `{ "bla": true }` + + t.Run("HTTP", func(t *testing.T) { + port := freeport.GetPort() + uncorsApp, uri := createApp(ctx, t, fs, false, &config.UncorsConfig{ + HTTPPort: port, + Mappings: config.Mappings{ + config.Mapping{ + From: "http://127.0.0.1", + To: "https://github.com", + Mocks: mocks(expectedResponse), + }, + }, + }) + defer func() { + err := uncorsApp.Close() + testutils.CheckNoServerError(t, err) + }() + + response := makeRequest(t, http.DefaultClient, uri) + assert.Equal(t, expectedResponse, response) + + uncorsApp.Restart(ctx, &config.UncorsConfig{ + HTTPPort: port, + Mappings: config.Mappings{ + config.Mapping{ + From: "https://127.0.0.1", + To: "https://github.com", + Mocks: mocks(otherExpectedRepose), + }, + }, + }) + + time.Sleep(delay) + + response2 := makeRequest(t, http.DefaultClient, uri) + + assert.Equal(t, otherExpectedRepose, response2) + }) + + t.Run("HTTPS", testutils.WithTmpCerts(fs, func(t *testing.T, certs *testutils.Certs) { + port := freeport.GetPort() + uncorsApp, uri := createApp(ctx, t, fs, true, &config.UncorsConfig{ + HTTPSPort: port, + CertFile: certs.CertPath, + KeyFile: certs.KeyPath, + Mappings: config.Mappings{ + config.Mapping{ + From: "https://127.0.0.1", + To: "https://github.com", + Mocks: mocks(expectedResponse), + }, + }, + }) + defer func() { + err := uncorsApp.Close() + testutils.CheckNoServerError(t, err) + }() + + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: certs.ClientTLSConf, + }, + } + + response := makeRequest(t, httpClient, uri) + + assert.Equal(t, expectedResponse, response) + + uncorsApp.Restart(ctx, &config.UncorsConfig{ + HTTPSPort: port, + CertFile: certs.CertPath, + KeyFile: certs.KeyPath, + Mappings: config.Mappings{ + config.Mapping{ + From: "https://127.0.0.1", + To: "https://github.com", + Mocks: mocks(otherExpectedRepose), + }, + }, + }) + + time.Sleep(delay) + + response2 := makeRequest(t, httpClient, uri) + + assert.Equal(t, otherExpectedRepose, response2) + })) + })) +} + +func makeRequest(t *testing.T, httpClient *http.Client, uri *url.URL) string { + res, err := httpClient.Do(&http.Request{URL: uri, Method: http.MethodGet}) + testutils.CheckNoError(t, err) + defer helpers.CloseSafe(res.Body) + + data, err := io.ReadAll(res.Body) + testutils.CheckNoError(t, err) + + return string(data) +} + +func createApp( + ctx context.Context, + t *testing.T, fs afero.Fs, https bool, config *config.UncorsConfig, +) (*uncors.App, *url.URL) { + app := uncors.CreateApp(fs, "x.x.x") + + go app.Start(ctx, config) + + time.Sleep(delay) + + prefix := "http://" + if https { + prefix = "https://" + } + addr := app.HTTPAddr().String() + if https { + addr = app.HTTPSAddr().String() + } + uri, err := url.Parse(prefix + addr) + testutils.CheckNoError(t, err) + + return app, uri +} + +func mocks(response string) config.Mocks { + return config.Mocks{ + config.Mock{ + Path: "/", + Response: config.Response{ + Code: http.StatusOK, + Raw: response, + }, + }, + } +} diff --git a/internal/uncors/handler.go b/internal/uncors/handler.go new file mode 100644 index 00000000..d0d1525b --- /dev/null +++ b/internal/uncors/handler.go @@ -0,0 +1,88 @@ +package uncors + +import ( + "time" + + "github.com/evg4b/uncors/internal/config" + "github.com/evg4b/uncors/internal/contracts" + "github.com/evg4b/uncors/internal/handler" + cache2 "github.com/evg4b/uncors/internal/handler/cache" + "github.com/evg4b/uncors/internal/handler/mock" + "github.com/evg4b/uncors/internal/handler/proxy" + "github.com/evg4b/uncors/internal/handler/static" + "github.com/evg4b/uncors/internal/infra" + "github.com/evg4b/uncors/internal/urlreplacer" + "github.com/patrickmn/go-cache" + "github.com/spf13/afero" +) + +type appCache struct { + staticHandlerFactory handler.RequestHandlerOption + mockHandlerFactory handler.RequestHandlerOption +} + +func (app *App) buildHandler(uncorsConfig *config.UncorsConfig) *handler.RequestHandler { + globalHandler := handler.NewUncorsRequestHandler( + handler.WithMappings(uncorsConfig.Mappings), + handler.WithLogger(MockLogger), + handler.WithCacheMiddlewareFactory(func(globs config.CacheGlobs) contracts.Middleware { + cacheConfig := uncorsConfig.CacheConfig + // TODO: Add cache storage reusage + cacheStorage := cache.New(cacheConfig.ExpirationTime, cacheConfig.ClearTime) + + return cache2.NewMiddleware( + cache2.WithLogger(CacheLogger), + cache2.WithMethods(cacheConfig.Methods), + cache2.WithCacheStorage(cacheStorage), + cache2.WithGlobs(globs), + ) + }), + handler.WithProxyHandlerFactory(func() contracts.Handler { + factory := urlreplacer.NewURLReplacerFactory(uncorsConfig.Mappings) + httpClient := infra.MakeHTTPClient(uncorsConfig.Proxy) + + return proxy.NewProxyHandler( + proxy.WithURLReplacerFactory(factory), + proxy.WithHTTPClient(httpClient), + proxy.WithLogger(ProxyLogger), + ) + }), + app.getWithStaticHandlerFactory(), + app.getMockHandlerFactory(), + ) + + return globalHandler +} + +func (app *App) getMockHandlerFactory() handler.RequestHandlerOption { + if app.cache.mockHandlerFactory == nil { + factoryFunc := func(response config.Response) contracts.Handler { + return mock.NewMockHandler( + mock.WithLogger(MockLogger), + mock.WithResponse(response), + mock.WithFileSystem(app.fs), + mock.WithAfter(time.After), + ) + } + app.cache.mockHandlerFactory = handler.WithMockHandlerFactory(factoryFunc) + } + + return app.cache.mockHandlerFactory +} + +func (app *App) getWithStaticHandlerFactory() handler.RequestHandlerOption { + if app.cache.staticHandlerFactory == nil { + factoryFunc := func(path string, dir config.StaticDirectory) contracts.Middleware { + return static.NewStaticMiddleware( + static.WithFileSystem(afero.NewBasePathFs(app.fs, dir.Dir)), + static.WithIndex(dir.Index), + static.WithLogger(StaticLogger), + static.WithPrefix(path), + ) + } + + app.cache.staticHandlerFactory = handler.WithStaticHandlerFactory(factoryFunc) + } + + return app.cache.staticHandlerFactory +} diff --git a/internal/uncors/listen.go b/internal/uncors/listen.go new file mode 100644 index 00000000..2acdc5ca --- /dev/null +++ b/internal/uncors/listen.go @@ -0,0 +1,63 @@ +// nolint: wrapcheck +package uncors + +import ( + "context" + "errors" + "net" + "net/http" +) + +type serveConfig struct { + addr string + serve func(l net.Listener) error + setListener func(l net.Listener) +} + +func (app *App) listenAndServe(addr string) error { + return app.internalServe(&serveConfig{ + addr: addr, + serve: app.server.Serve, // nolint: wrapcheck + setListener: func(l net.Listener) { + app.httpListener = l + }, + }) +} + +func (app *App) listenAndServeTLS(addr string, certFile, keyFile string) error { + return app.internalServe(&serveConfig{ + addr: addr, + serve: func(l net.Listener) error { + return app.server.ServeTLS(l, certFile, keyFile) // nolint: wrapcheck + }, + setListener: func(l net.Listener) { + app.httpsListener = l + }, + }) +} + +func (app *App) internalServe(config *serveConfig) error { + if app.shuttingDown.Load() { + return http.ErrServerClosed + } + + listener, err := net.Listen("tcp", config.addr) + if err != nil { + return err + } + + config.setListener(listener) + defer func() { config.setListener(nil) }() + + err = config.serve(listener) + if err != nil && !errors.Is(err, http.ErrServerClosed) { + shutdownError := app.internalShutdown(context.TODO()) + if shutdownError != nil && !errors.Is(shutdownError, http.ErrServerClosed) { + panic(shutdownError) + } + + return err + } + + return nil +} diff --git a/internal/ui/loggers.go b/internal/uncors/loggers.go similarity index 97% rename from internal/ui/loggers.go rename to internal/uncors/loggers.go index 858cea86..82df3531 100644 --- a/internal/ui/loggers.go +++ b/internal/uncors/loggers.go @@ -1,4 +1,4 @@ -package ui +package uncors import ( "github.com/evg4b/uncors/internal/log" diff --git a/internal/ui/logo.go b/internal/uncors/logo.go similarity index 98% rename from internal/ui/logo.go rename to internal/uncors/logo.go index 1dd0d509..c76599f8 100644 --- a/internal/ui/logo.go +++ b/internal/uncors/logo.go @@ -1,4 +1,4 @@ -package ui +package uncors import ( "strings" diff --git a/internal/ui/logo_test.go b/internal/uncors/logo_test.go similarity index 97% rename from internal/ui/logo_test.go rename to internal/uncors/logo_test.go index dbcc3acd..98bdc444 100644 --- a/internal/ui/logo_test.go +++ b/internal/uncors/logo_test.go @@ -1,9 +1,10 @@ -package ui_test +package uncors_test import ( "testing" - "github.com/evg4b/uncors/internal/ui" + "github.com/evg4b/uncors/internal/uncors" + "github.com/pterm/pterm" "github.com/stretchr/testify/assert" ) @@ -45,6 +46,6 @@ var expectedLogo = []byte{ func TestLogo(t *testing.T) { pterm.DisableColor() - logo := ui.Logo("X.Y.Z") + logo := uncors.Logo("X.Y.Z") assert.Equal(t, expectedLogo, []byte(logo)) } diff --git a/internal/ui/messages.go b/internal/uncors/messages.go similarity index 95% rename from internal/ui/messages.go rename to internal/uncors/messages.go index eca97576..46cf7e35 100644 --- a/internal/ui/messages.go +++ b/internal/uncors/messages.go @@ -1,4 +1,4 @@ -package ui +package uncors const DisclaimerMessage = `DON'T USE IT FOR PRODUCTION! This is a reverse proxy for use in testing or debugging web applications locally. diff --git a/internal/uncors/shutdown.go b/internal/uncors/shutdown.go new file mode 100644 index 00000000..7cbbe1f4 --- /dev/null +++ b/internal/uncors/shutdown.go @@ -0,0 +1,34 @@ +package uncors + +import ( + "context" + "errors" + + "github.com/evg4b/uncors/internal/log" +) + +func (app *App) internalShutdown(rootCtx context.Context) error { + if app.server == nil { + return nil + } + + app.shuttingDown.Store(true) + ctx, cancel := context.WithTimeout(rootCtx, shutdownTimeout) + defer cancel() + + log.Debug("uncors: shutting down ...") + + if err := app.server.Shutdown(ctx); err != nil { + if errors.Is(err, context.DeadlineExceeded) { + log.Errorf("shutdown timeout for UNCORS server") + } else { + log.Errorf("error while shutting down UNCORS server: %s", err) + } + + return err // nolint: wrapcheck + } + + log.Debug("UNCORS server closed") + + return nil +} diff --git a/internal/version/new_version_check.go b/internal/version/new_version_check.go index b84608c8..c2de20e7 100644 --- a/internal/version/new_version_check.go +++ b/internal/version/new_version_check.go @@ -10,7 +10,7 @@ import ( "github.com/evg4b/uncors/internal/contracts" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/log" - "github.com/evg4b/uncors/internal/ui" + "github.com/evg4b/uncors/internal/uncors" "github.com/hashicorp/go-version" ) @@ -63,7 +63,7 @@ func CheckNewVersion(ctx context.Context, client contracts.HTTPClient, rawCurren } if lastVersion.GreaterThan(currentVersion) { - log.Infof(ui.NewVersionIsAvailable, currentVersion.String(), lastVersion.String()) + log.Infof(uncors.NewVersionIsAvailable, currentVersion.String(), lastVersion.String()) log.Info("\n") } else { log.Debug("Version is up to date") diff --git a/main.go b/main.go index 788b9389..cf143f10 100644 --- a/main.go +++ b/main.go @@ -1,30 +1,16 @@ +// nolint: wrapcheck package main import ( - "errors" - "fmt" - "net" - "net/http" "os" - "strconv" - "time" - cf "github.com/evg4b/uncors/internal/config" - c "github.com/evg4b/uncors/internal/contracts" - "github.com/evg4b/uncors/internal/handler" - "github.com/evg4b/uncors/internal/handler/cache" - "github.com/evg4b/uncors/internal/handler/mock" - "github.com/evg4b/uncors/internal/handler/proxy" - "github.com/evg4b/uncors/internal/handler/static" + "github.com/evg4b/uncors/internal/config" "github.com/evg4b/uncors/internal/helpers" "github.com/evg4b/uncors/internal/infra" "github.com/evg4b/uncors/internal/log" - "github.com/evg4b/uncors/internal/server" - "github.com/evg4b/uncors/internal/ui" - "github.com/evg4b/uncors/internal/urlreplacer" + "github.com/evg4b/uncors/internal/uncors" "github.com/evg4b/uncors/internal/version" - goCache "github.com/patrickmn/go-cache" - "github.com/pseidemann/finish" + "github.com/fsnotify/fsnotify" "github.com/spf13/afero" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -33,8 +19,6 @@ import ( var Version = "X.X.X" -const baseAddress = "127.0.0.1" - func main() { defer helpers.PanicInterceptor(func(value any) { log.Error(value) @@ -42,116 +26,49 @@ func main() { }) pflag.Usage = func() { - ui.Logo(Version) + uncors.Logo(Version) helpers.Fprintf(os.Stdout, "Usage of %s:\n", os.Args[0]) pflag.PrintDefaults() } - uncorsConfig := cf.LoadConfiguration(viper.GetViper(), os.Args) - - if err := cf.Validate(uncorsConfig); err != nil { - panic(err) - } - - if uncorsConfig.Debug { - log.EnableDebugMessages() - log.Debug("Enabled debug messages") - } - - mappings := cf.NormaliseMappings( - uncorsConfig.Mappings, - uncorsConfig.HTTPPort, - uncorsConfig.HTTPSPort, - uncorsConfig.IsHTTPSEnabled(), - ) - - factory := urlreplacer.NewURLReplacerFactory(mappings) - httpClient := infra.MakeHTTPClient(viper.GetString("proxy")) - - cacheConfig := uncorsConfig.CacheConfig - cacheStorage := goCache.New(cacheConfig.ExpirationTime, cacheConfig.ClearTime) + viperInstance := viper.GetViper() + uncorsConfig := loadConfiguration(viperInstance) fs := afero.NewOsFs() - globalHandler := handler.NewUncorsRequestHandler( - handler.WithMappings(mappings), - handler.WithLogger(ui.MockLogger), - handler.WithCacheMiddlewareFactory(func(globs cf.CacheGlobs) c.Middleware { - return cache.NewMiddleware( - cache.WithLogger(ui.CacheLogger), - cache.WithMethods(cacheConfig.Methods), - cache.WithCacheStorage(cacheStorage), - cache.WithGlobs(globs), - ) - }), - handler.WithProxyHandlerFactory(func() c.Handler { - return proxy.NewProxyHandler( - proxy.WithURLReplacerFactory(factory), - proxy.WithHTTPClient(httpClient), - proxy.WithLogger(ui.ProxyLogger), - ) - }), - handler.WithStaticHandlerFactory(func(path string, dir cf.StaticDirectory) c.Middleware { - return static.NewStaticMiddleware( - static.WithFileSystem(afero.NewBasePathFs(fs, dir.Dir)), - static.WithIndex(dir.Index), - static.WithLogger(ui.StaticLogger), - static.WithPrefix(path), - ) - }), - handler.WithMockHandlerFactory(func(response cf.Response) c.Handler { - return mock.NewMockHandler( - mock.WithLogger(ui.MockLogger), - mock.WithResponse(response), - mock.WithFileSystem(fs), - mock.WithAfter(time.After), - ) - }), - ) - - finisher := finish.Finisher{Log: infra.NoopLogger{}} - ctx := context.Background() + app := uncors.CreateApp(fs, Version) + viperInstance.OnConfigChange(func(in fsnotify.Event) { + defer helpers.PanicInterceptor(func(value any) { + log.Errorf("Config reloading value %v", value) + }) - uncorsServer := server.NewUncorsServer(ctx, globalHandler) - - log.Print(ui.Logo(Version)) - log.Print("\n") - log.Warning(ui.DisclaimerMessage) - log.Print("\n") - log.Info(mappings.String()) - log.Print("\n") - - finisher.Add(uncorsServer) + app.Restart(ctx, loadConfiguration(viperInstance)) + }) + viperInstance.WatchConfig() + go version.CheckNewVersion(ctx, infra.MakeHTTPClient(uncorsConfig.Proxy), Version) + app.Start(ctx, uncorsConfig) go func() { - defer finisher.Trigger() - log.Debugf("Starting http server on port %d", uncorsConfig.HTTPPort) - addr := net.JoinHostPort(baseAddress, strconv.Itoa(uncorsConfig.HTTPPort)) - err := uncorsServer.ListenAndServe(addr) - handleHTTPServerError("HTTP", err) + shutdownErr := helpers.GracefulShutdown(ctx, func(shutdownCtx context.Context) error { + log.Debug("shutdown signal received") + + return app.Shutdown(shutdownCtx) + }) + if shutdownErr != nil { + panic(shutdownErr) + } }() - - if uncorsConfig.IsHTTPSEnabled() { - log.Debug("Found cert file and key file. Https server will be started") - addr := net.JoinHostPort(baseAddress, strconv.Itoa(uncorsConfig.HTTPSPort)) - go func() { - defer finisher.Trigger() - log.Debugf("Starting https server on port %d", uncorsConfig.HTTPSPort) - err := uncorsServer.ListenAndServeTLS(addr, uncorsConfig.CertFile, uncorsConfig.KeyFile) - handleHTTPServerError("HTTPS", err) - }() - } - - go version.CheckNewVersion(ctx, httpClient, Version) - - finisher.Wait() - + app.Wait() log.Info("Server was stopped") } -func handleHTTPServerError(serverName string, err error) { - if err == nil || errors.Is(err, http.ErrServerClosed) { - log.Debugf("%s server was stopped without errors", serverName) +func loadConfiguration(viperInstance *viper.Viper) *config.UncorsConfig { + uncorsConfig := config.LoadConfiguration(viperInstance, os.Args) + if uncorsConfig.Debug { + log.EnableDebugMessages() + log.Debug("Enabled debug messages") } else { - panic(fmt.Errorf("%s server was stopped with error %w", serverName, err)) + log.DisableDebugMessages() } + + return uncorsConfig } diff --git a/testing/mocks/closer_mock.go b/testing/mocks/closer_mock.go index 281ecdce..38a6ace5 100644 --- a/testing/mocks/closer_mock.go +++ b/testing/mocks/closer_mock.go @@ -90,7 +90,7 @@ func (mmClose *mCloserMockClose) Return(err error) *CloserMock { return mmClose.mock } -//Set uses given function f to mock the Closer.Close method +// Set uses given function f to mock the Closer.Close method func (mmClose *mCloserMockClose) Set(f func() (err error)) *CloserMock { if mmClose.defaultExpectation != nil { mmClose.mock.t.Fatalf("Default expectation is already set for the Closer.Close method") diff --git a/testing/mocks/http_client_mock.go b/testing/mocks/http_client_mock.go index 518a6deb..b89e6402 100644 --- a/testing/mocks/http_client_mock.go +++ b/testing/mocks/http_client_mock.go @@ -109,7 +109,7 @@ func (mmDo *mHTTPClientMockDo) Return(rp1 *http.Response, err error) *HTTPClient return mmDo.mock } -//Set uses given function f to mock the HTTPClient.Do method +// Set uses given function f to mock the HTTPClient.Do method func (mmDo *mHTTPClientMockDo) Set(f func(req *http.Request) (rp1 *http.Response, err error)) *HTTPClientMock { if mmDo.defaultExpectation != nil { mmDo.mock.t.Fatalf("Default expectation is already set for the HTTPClient.Do method") diff --git a/testing/mocks/logger_mock.go b/testing/mocks/logger_mock.go index d585ce70..fb1a4d84 100644 --- a/testing/mocks/logger_mock.go +++ b/testing/mocks/logger_mock.go @@ -175,7 +175,7 @@ func (mmDebug *mLoggerMockDebug) Return() *LoggerMock { return mmDebug.mock } -//Set uses given function f to mock the Logger.Debug method +// Set uses given function f to mock the Logger.Debug method func (mmDebug *mLoggerMockDebug) Set(f func(a ...any)) *LoggerMock { if mmDebug.defaultExpectation != nil { mmDebug.mock.t.Fatalf("Default expectation is already set for the Logger.Debug method") @@ -228,7 +228,6 @@ func (mmDebug *LoggerMock) Debug(a ...any) { return } mmDebug.t.Fatalf("Unexpected call to LoggerMock.Debug. %v", a) - } // DebugAfterCounter returns a count of finished LoggerMock.Debug invocations @@ -363,7 +362,7 @@ func (mmDebugf *mLoggerMockDebugf) Return() *LoggerMock { return mmDebugf.mock } -//Set uses given function f to mock the Logger.Debugf method +// Set uses given function f to mock the Logger.Debugf method func (mmDebugf *mLoggerMockDebugf) Set(f func(template string, a ...any)) *LoggerMock { if mmDebugf.defaultExpectation != nil { mmDebugf.mock.t.Fatalf("Default expectation is already set for the Logger.Debugf method") @@ -416,7 +415,6 @@ func (mmDebugf *LoggerMock) Debugf(template string, a ...any) { return } mmDebugf.t.Fatalf("Unexpected call to LoggerMock.Debugf. %v %v", template, a) - } // DebugfAfterCounter returns a count of finished LoggerMock.Debugf invocations @@ -550,7 +548,7 @@ func (mmError *mLoggerMockError) Return() *LoggerMock { return mmError.mock } -//Set uses given function f to mock the Logger.Error method +// Set uses given function f to mock the Logger.Error method func (mmError *mLoggerMockError) Set(f func(a ...any)) *LoggerMock { if mmError.defaultExpectation != nil { mmError.mock.t.Fatalf("Default expectation is already set for the Logger.Error method") @@ -603,7 +601,6 @@ func (mmError *LoggerMock) Error(a ...any) { return } mmError.t.Fatalf("Unexpected call to LoggerMock.Error. %v", a) - } // ErrorAfterCounter returns a count of finished LoggerMock.Error invocations @@ -738,7 +735,7 @@ func (mmErrorf *mLoggerMockErrorf) Return() *LoggerMock { return mmErrorf.mock } -//Set uses given function f to mock the Logger.Errorf method +// Set uses given function f to mock the Logger.Errorf method func (mmErrorf *mLoggerMockErrorf) Set(f func(template string, a ...any)) *LoggerMock { if mmErrorf.defaultExpectation != nil { mmErrorf.mock.t.Fatalf("Default expectation is already set for the Logger.Errorf method") @@ -791,7 +788,6 @@ func (mmErrorf *LoggerMock) Errorf(template string, a ...any) { return } mmErrorf.t.Fatalf("Unexpected call to LoggerMock.Errorf. %v %v", template, a) - } // ErrorfAfterCounter returns a count of finished LoggerMock.Errorf invocations @@ -925,7 +921,7 @@ func (mmInfo *mLoggerMockInfo) Return() *LoggerMock { return mmInfo.mock } -//Set uses given function f to mock the Logger.Info method +// Set uses given function f to mock the Logger.Info method func (mmInfo *mLoggerMockInfo) Set(f func(a ...any)) *LoggerMock { if mmInfo.defaultExpectation != nil { mmInfo.mock.t.Fatalf("Default expectation is already set for the Logger.Info method") @@ -978,7 +974,6 @@ func (mmInfo *LoggerMock) Info(a ...any) { return } mmInfo.t.Fatalf("Unexpected call to LoggerMock.Info. %v", a) - } // InfoAfterCounter returns a count of finished LoggerMock.Info invocations @@ -1113,7 +1108,7 @@ func (mmInfof *mLoggerMockInfof) Return() *LoggerMock { return mmInfof.mock } -//Set uses given function f to mock the Logger.Infof method +// Set uses given function f to mock the Logger.Infof method func (mmInfof *mLoggerMockInfof) Set(f func(template string, a ...any)) *LoggerMock { if mmInfof.defaultExpectation != nil { mmInfof.mock.t.Fatalf("Default expectation is already set for the Logger.Infof method") @@ -1166,7 +1161,6 @@ func (mmInfof *LoggerMock) Infof(template string, a ...any) { return } mmInfof.t.Fatalf("Unexpected call to LoggerMock.Infof. %v %v", template, a) - } // InfofAfterCounter returns a count of finished LoggerMock.Infof invocations @@ -1301,7 +1295,7 @@ func (mmPrintResponse *mLoggerMockPrintResponse) Return() *LoggerMock { return mmPrintResponse.mock } -//Set uses given function f to mock the Logger.PrintResponse method +// Set uses given function f to mock the Logger.PrintResponse method func (mmPrintResponse *mLoggerMockPrintResponse) Set(f func(request *mm_contracts.Request, code int)) *LoggerMock { if mmPrintResponse.defaultExpectation != nil { mmPrintResponse.mock.t.Fatalf("Default expectation is already set for the Logger.PrintResponse method") @@ -1354,7 +1348,6 @@ func (mmPrintResponse *LoggerMock) PrintResponse(request *mm_contracts.Request, return } mmPrintResponse.t.Fatalf("Unexpected call to LoggerMock.PrintResponse. %v %v", request, code) - } // PrintResponseAfterCounter returns a count of finished LoggerMock.PrintResponse invocations @@ -1488,7 +1481,7 @@ func (mmWarning *mLoggerMockWarning) Return() *LoggerMock { return mmWarning.mock } -//Set uses given function f to mock the Logger.Warning method +// Set uses given function f to mock the Logger.Warning method func (mmWarning *mLoggerMockWarning) Set(f func(a ...any)) *LoggerMock { if mmWarning.defaultExpectation != nil { mmWarning.mock.t.Fatalf("Default expectation is already set for the Logger.Warning method") @@ -1541,7 +1534,6 @@ func (mmWarning *LoggerMock) Warning(a ...any) { return } mmWarning.t.Fatalf("Unexpected call to LoggerMock.Warning. %v", a) - } // WarningAfterCounter returns a count of finished LoggerMock.Warning invocations @@ -1676,7 +1668,7 @@ func (mmWarningf *mLoggerMockWarningf) Return() *LoggerMock { return mmWarningf.mock } -//Set uses given function f to mock the Logger.Warningf method +// Set uses given function f to mock the Logger.Warningf method func (mmWarningf *mLoggerMockWarningf) Set(f func(template string, a ...any)) *LoggerMock { if mmWarningf.defaultExpectation != nil { mmWarningf.mock.t.Fatalf("Default expectation is already set for the Logger.Warningf method") @@ -1729,7 +1721,6 @@ func (mmWarningf *LoggerMock) Warningf(template string, a ...any) { return } mmWarningf.t.Fatalf("Unexpected call to LoggerMock.Warningf. %v %v", template, a) - } // WarningfAfterCounter returns a count of finished LoggerMock.Warningf invocations diff --git a/testing/mocks/replacer_factory_mock.go b/testing/mocks/replacer_factory_mock.go index 9d64b1c6..7278ef05 100644 --- a/testing/mocks/replacer_factory_mock.go +++ b/testing/mocks/replacer_factory_mock.go @@ -111,7 +111,7 @@ func (mmMake *mReplacerFactoryMockMake) Return(rp1 *mm_urlreplacer.Replacer, rp2 return mmMake.mock } -//Set uses given function f to mock the ReplacerFactory.Make method +// Set uses given function f to mock the ReplacerFactory.Make method func (mmMake *mReplacerFactoryMockMake) Set(f func(requestURL *url.URL) (rp1 *mm_urlreplacer.Replacer, rp2 *mm_urlreplacer.Replacer, err error)) *ReplacerFactoryMock { if mmMake.defaultExpectation != nil { mmMake.mock.t.Fatalf("Default expectation is already set for the ReplacerFactory.Make method") diff --git a/testing/mocks/writer_mock.go b/testing/mocks/writer_mock.go index 85ad6148..ec605adb 100644 --- a/testing/mocks/writer_mock.go +++ b/testing/mocks/writer_mock.go @@ -108,7 +108,7 @@ func (mmWrite *mWriterMockWrite) Return(n int, err error) *WriterMock { return mmWrite.mock } -//Set uses given function f to mock the Writer.Write method +// Set uses given function f to mock the Writer.Write method func (mmWrite *mWriterMockWrite) Set(f func(p []byte) (n int, err error)) *WriterMock { if mmWrite.defaultExpectation != nil { mmWrite.mock.t.Fatalf("Default expectation is already set for the Writer.Write method") diff --git a/testing/testutils/certs.go b/testing/testutils/certs.go index 19361995..c7bce2dd 100644 --- a/testing/testutils/certs.go +++ b/testing/testutils/certs.go @@ -14,6 +14,8 @@ import ( "path" "testing" "time" + + "github.com/spf13/afero" ) type Certs struct { @@ -23,14 +25,18 @@ type Certs struct { KeyPath string } -func WithTmpCerts(action func(t *testing.T, certs *Certs)) func(t *testing.T) { +func WithTmpCerts(fs afero.Fs, action func(t *testing.T, certs *Certs)) func(t *testing.T) { + if fs == nil { + fs = afero.NewOsFs() + } + return func(t *testing.T) { - certs := certSetup(t) + certs := certSetup(t, fs) action(t, certs) } } -func certSetup(t *testing.T) *Certs { +func certSetup(t *testing.T, fs afero.Fs) *Certs { t.Helper() now := time.Now() @@ -100,11 +106,13 @@ func certSetup(t *testing.T) *Certs { }) CheckNoError(t, err) - tmpDir := t.TempDir() + tmpDir, err := afero.TempDir(fs, "", "uncors_") + CheckNoError(t, err) + certPath := path.Join(tmpDir, "test.cert") keyPath := path.Join(tmpDir, "test.key") - err = os.WriteFile(certPath, certPEM, os.ModePerm) + err = afero.WriteFile(fs, certPath, certPEM, os.ModePerm) CheckNoError(t, err) privateKeyPEM := pem.EncodeToMemory(&pem.Block{ @@ -113,7 +121,7 @@ func certSetup(t *testing.T) *Certs { }) CheckNoError(t, err) - err = os.WriteFile(keyPath, privateKeyPEM, os.ModePerm) + err = afero.WriteFile(fs, keyPath, privateKeyPEM, os.ModePerm) CheckNoError(t, err) serverCert, err := tls.X509KeyPair(certPEM, privateKeyPEM)