diff --git a/Dockerfile b/Dockerfile index cccdd61b0..07175f1cf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -197,6 +197,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ # Control server HTTP_CONTROL_SERVER_LOG=on \ HTTP_CONTROL_SERVER_ADDRESS=":8000" \ + HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH=/gluetun/auth/config.toml \ # Server data updater UPDATER_PERIOD=0 \ UPDATER_MIN_RATIO=0.8 \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 0bfa2d535..8ccfa15d1 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -161,12 +161,14 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return cli.Update(ctx, args[2:], logger) case "format-servers": return cli.FormatServers(args[2:]) + case "genkey": + return cli.GenKey(args[2:]) default: return fmt.Errorf("%w: %s", errCommandUnknown, args[1]) } } - announcementExp, err := time.Parse(time.RFC3339, "2023-07-01T00:00:00Z") + announcementExp, err := time.Parse(time.RFC3339, "2024-12-01T00:00:00Z") if err != nil { return err } @@ -177,7 +179,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, Version: buildInfo.Version, Commit: buildInfo.Commit, Created: buildInfo.Created, - Announcement: "Wiki moved to https://github.com/qdm12/gluetun-wiki", + Announcement: "All control server routes will become private by default after the v3.41.0 release", AnnounceExp: announcementExp, // Sponsor information PaypalUser: "qmcgaw", @@ -474,6 +476,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, "http server", goroutine.OptionTimeout(defaultShutdownTimeout)) httpServer, err := server.New(httpServerCtx, controlServerAddress, controlServerLogging, logger.New(log.SetComponent("http server")), + allSettings.ControlServer.AuthFilePath, buildInfo, vpnLooper, portForwardLooper, unboundLooper, updaterLooper, publicIPLooper, storage, ipv6Supported) if err != nil { @@ -595,6 +598,7 @@ type clier interface { OpenvpnConfig(logger cli.OpenvpnConfigLogger, reader *reader.Reader, ipv6Checker cli.IPv6Checker) error HealthCheck(ctx context.Context, reader *reader.Reader, warner cli.Warner) error Update(ctx context.Context, args []string, logger cli.UpdaterLogger) error + GenKey(args []string) error } type Tun interface { diff --git a/go.mod b/go.mod index 5a1189233..2812fa1b5 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/golang/mock v1.6.0 github.com/klauspost/compress v1.17.8 github.com/klauspost/pgzip v1.2.6 + github.com/pelletier/go-toml/v2 v2.2.2 github.com/qdm12/dns v1.11.0 github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6 github.com/qdm12/gosettings v0.4.2 diff --git a/go.sum b/go.sum index 6678999e4..069186f97 100644 --- a/go.sum +++ b/go.sum @@ -83,6 +83,8 @@ github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o= github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pelletier/go-buffruneio v0.2.0/go.mod h1:JkE26KsDizTr40EUHkXVtNPvgGtbSNq5BcowyYOWdKo= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/phayes/permbits v0.0.0-20190612203442-39d7c581d2ee/go.mod h1:3uODdxMgOaPYeWU7RzZLxVtJHZ/x1f/iHkBZuKJDzuY= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -113,10 +115,16 @@ github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAm github.com/src-d/gcfg v1.4.0/go.mod h1:p/UMsR43ujA89BJY9duynAwIpvqEujIH/jFlfL7jWoI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8= diff --git a/internal/cli/genkey.go b/internal/cli/genkey.go new file mode 100644 index 000000000..ac161cf00 --- /dev/null +++ b/internal/cli/genkey.go @@ -0,0 +1,66 @@ +package cli + +import ( + "crypto/rand" + "flag" + "fmt" +) + +func (c *CLI) GenKey(args []string) (err error) { + flagSet := flag.NewFlagSet("genkey", flag.ExitOnError) + err = flagSet.Parse(args) + if err != nil { + return fmt.Errorf("parsing flags: %w", err) + } + + const keyLength = 128 / 8 + keyBytes := make([]byte, keyLength) + + _, _ = rand.Read(keyBytes) + + key := base58Encode(keyBytes) + fmt.Println(key) + + return nil +} + +func base58Encode(data []byte) string { + const alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" + const radix = 58 + + zcount := 0 + for zcount < len(data) && data[zcount] == 0 { + zcount++ + } + + // integer simplification of ceil(log(256)/log(58)) + ceilLog256Div58 := (len(data)-zcount)*555/406 + 1 //nolint:gomnd + size := zcount + ceilLog256Div58 + + output := make([]byte, size) + + high := size - 1 + for _, b := range data { + i := size - 1 + for carry := uint32(b); i > high || carry != 0; i-- { + carry += 256 * uint32(output[i]) //nolint:gomnd + output[i] = byte(carry % radix) + carry /= radix + } + high = i + } + + // Determine the additional "zero-gap" in the output buffer + additionalZeroGapEnd := zcount + for additionalZeroGapEnd < size && output[additionalZeroGapEnd] == 0 { + additionalZeroGapEnd++ + } + + val := output[additionalZeroGapEnd-zcount:] + size = len(val) + for i := range val { + output[i] = alphabet[val[i]] + } + + return string(output[:size]) +} diff --git a/internal/configuration/settings/server.go b/internal/configuration/settings/server.go index 82f2773a8..155d5cdb3 100644 --- a/internal/configuration/settings/server.go +++ b/internal/configuration/settings/server.go @@ -19,6 +19,11 @@ type ControlServer struct { // Log can be true or false to enable logging on requests. // It cannot be nil in the internal state. Log *bool + // AuthFilePath is the path to the file containing the authentication + // configuration for the middleware. + // It cannot be empty in the internal state and defaults to + // /gluetun/auth/config.toml. + AuthFilePath string } func (c ControlServer) validate() (err error) { @@ -44,8 +49,9 @@ func (c ControlServer) validate() (err error) { func (c *ControlServer) copy() (copied ControlServer) { return ControlServer{ - Address: gosettings.CopyPointer(c.Address), - Log: gosettings.CopyPointer(c.Log), + Address: gosettings.CopyPointer(c.Address), + Log: gosettings.CopyPointer(c.Log), + AuthFilePath: c.AuthFilePath, } } @@ -55,11 +61,13 @@ func (c *ControlServer) copy() (copied ControlServer) { func (c *ControlServer) overrideWith(other ControlServer) { c.Address = gosettings.OverrideWithPointer(c.Address, other.Address) c.Log = gosettings.OverrideWithPointer(c.Log, other.Log) + c.AuthFilePath = gosettings.OverrideWithComparable(c.AuthFilePath, other.AuthFilePath) } func (c *ControlServer) setDefaults() { c.Address = gosettings.DefaultPointer(c.Address, ":8000") c.Log = gosettings.DefaultPointer(c.Log, true) + c.AuthFilePath = gosettings.DefaultComparable(c.AuthFilePath, "/gluetun/auth/config.toml") } func (c ControlServer) String() string { @@ -70,6 +78,7 @@ func (c ControlServer) toLinesNode() (node *gotree.Node) { node = gotree.New("Control server settings:") node.Appendf("Listening address: %s", *c.Address) node.Appendf("Logging: %s", gosettings.BoolToYesNo(c.Log)) + node.Appendf("Authentication file path: %s", c.AuthFilePath) return node } @@ -78,6 +87,10 @@ func (c *ControlServer) read(r *reader.Reader) (err error) { if err != nil { return err } + c.Address = r.Get("HTTP_CONTROL_SERVER_ADDRESS") + + c.AuthFilePath = r.String("HTTP_CONTROL_SERVER_AUTH_CONFIG_FILEPATH") + return nil } diff --git a/internal/configuration/settings/settings_test.go b/internal/configuration/settings/settings_test.go index 9fa0f3517..bf4805bdf 100644 --- a/internal/configuration/settings/settings_test.go +++ b/internal/configuration/settings/settings_test.go @@ -78,7 +78,8 @@ func Test_Settings_String(t *testing.T) { | └── Enabled: no ├── Control server settings: | ├── Listening address: :8000 -| └── Logging: yes +| ├── Logging: yes +| └── Authentication file path: /gluetun/auth/config.toml ├── OS Alpine settings: | ├── Process UID: 1000 | └── Process GID: 1000 diff --git a/internal/server/handler.go b/internal/server/handler.go index 29ae2ca5b..4e68eb5da 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -2,13 +2,17 @@ package server import ( "context" + "fmt" "net/http" "strings" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/server/middlewares/auth" + "github.com/qdm12/gluetun/internal/server/middlewares/log" ) -func newHandler(ctx context.Context, logger infoWarner, logging bool, +func newHandler(ctx context.Context, logger Logger, logging bool, + authSettings auth.Settings, buildInfo models.BuildInformation, vpnLooper VPNLooper, pfGetter PortForwardedGetter, @@ -17,7 +21,7 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool, publicIPLooper PublicIPLoop, storage Storage, ipv6Supported bool, -) http.Handler { +) (httpHandler http.Handler, err error) { handler := &handler{} vpn := newVPNHandler(ctx, vpnLooper, storage, ipv6Supported, logger) @@ -29,16 +33,25 @@ func newHandler(ctx context.Context, logger infoWarner, logging bool, handler.v0 = newHandlerV0(ctx, logger, vpnLooper, unboundLooper, updaterLooper) handler.v1 = newHandlerV1(logger, buildInfo, vpn, openvpn, dns, updater, publicip) - handlerWithLog := withLogMiddleware(handler, logger, logging) - handler.setLogEnabled = handlerWithLog.setEnabled + authMiddleware, err := auth.New(authSettings, logger) + if err != nil { + return nil, fmt.Errorf("creating auth middleware: %w", err) + } - return handlerWithLog + middlewares := []func(http.Handler) http.Handler{ + authMiddleware, + log.New(logger, logging), + } + httpHandler = handler + for _, middleware := range middlewares { + httpHandler = middleware(httpHandler) + } + return httpHandler, nil } type handler struct { - v0 http.Handler - v1 http.Handler - setLogEnabled func(enabled bool) + v0 http.Handler + v1 http.Handler } func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/internal/server/logger.go b/internal/server/logger.go index 4b099c07a..4e847aa36 100644 --- a/internal/server/logger.go +++ b/internal/server/logger.go @@ -1,8 +1,10 @@ package server type Logger interface { + Debugf(format string, args ...any) infoer warner + Warnf(format string, args ...any) errorer } diff --git a/internal/server/middlewares/auth/apikey.go b/internal/server/middlewares/auth/apikey.go new file mode 100644 index 000000000..f9fd53834 --- /dev/null +++ b/internal/server/middlewares/auth/apikey.go @@ -0,0 +1,36 @@ +package auth + +import ( + "crypto/sha256" + "crypto/subtle" + "net/http" +) + +type apiKeyMethod struct { + apiKeyDigest [32]byte +} + +func newAPIKeyMethod(apiKey string) *apiKeyMethod { + return &apiKeyMethod{ + apiKeyDigest: sha256.Sum256([]byte(apiKey)), + } +} + +// equal returns true if another auth checker is equal. +// This is used to deduplicate checkers for a particular route. +func (a *apiKeyMethod) equal(other authorizationChecker) bool { + otherTokenMethod, ok := other.(*apiKeyMethod) + if !ok { + return false + } + return a.apiKeyDigest == otherTokenMethod.apiKeyDigest +} + +func (a *apiKeyMethod) isAuthorized(_ http.Header, request *http.Request) bool { + xAPIKey := request.Header.Get("X-API-Key") + if xAPIKey == "" { + xAPIKey = request.URL.Query().Get("api_key") + } + xAPIKeyDigest := sha256.Sum256([]byte(xAPIKey)) + return subtle.ConstantTimeCompare(xAPIKeyDigest[:], a.apiKeyDigest[:]) == 1 +} diff --git a/internal/server/middlewares/auth/basic.go b/internal/server/middlewares/auth/basic.go new file mode 100644 index 000000000..7017ed831 --- /dev/null +++ b/internal/server/middlewares/auth/basic.go @@ -0,0 +1,37 @@ +package auth + +import ( + "crypto/sha256" + "crypto/subtle" + "net/http" +) + +type basicAuthMethod struct { + authDigest [32]byte +} + +func newBasicAuthMethod(username, password string) *basicAuthMethod { + return &basicAuthMethod{ + authDigest: sha256.Sum256([]byte(username + password)), + } +} + +// equal returns true if another auth checker is equal. +// This is used to deduplicate checkers for a particular route. +func (a *basicAuthMethod) equal(other authorizationChecker) bool { + otherBasicMethod, ok := other.(*basicAuthMethod) + if !ok { + return false + } + return a.authDigest == otherBasicMethod.authDigest +} + +func (a *basicAuthMethod) isAuthorized(headers http.Header, request *http.Request) bool { + username, password, ok := request.BasicAuth() + if !ok { + headers.Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`) + return false + } + requestAuthDigest := sha256.Sum256([]byte(username + password)) + return subtle.ConstantTimeCompare(a.authDigest[:], requestAuthDigest[:]) == 1 +} diff --git a/internal/server/middlewares/auth/configfile.go b/internal/server/middlewares/auth/configfile.go new file mode 100644 index 000000000..1722a9107 --- /dev/null +++ b/internal/server/middlewares/auth/configfile.go @@ -0,0 +1,35 @@ +package auth + +import ( + "errors" + "fmt" + "os" + + "github.com/pelletier/go-toml/v2" +) + +// Read reads the toml file specified by the filepath given. +// If the file does not exist, it returns empty settings and no error. +func Read(filepath string) (settings Settings, err error) { + file, err := os.Open(filepath) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return Settings{}, nil + } + return settings, fmt.Errorf("opening file: %w", err) + } + decoder := toml.NewDecoder(file) + decoder.DisallowUnknownFields() + err = decoder.Decode(&settings) + if err == nil { + return settings, nil + } + + strictErr := new(toml.StrictMissingError) + ok := errors.As(err, &strictErr) + if !ok { + return settings, fmt.Errorf("toml decoding file: %w", err) + } + return settings, fmt.Errorf("toml decoding file: %w:\n%s", + strictErr, strictErr.String()) +} diff --git a/internal/server/middlewares/auth/configfile_test.go b/internal/server/middlewares/auth/configfile_test.go new file mode 100644 index 000000000..4dcc30097 --- /dev/null +++ b/internal/server/middlewares/auth/configfile_test.go @@ -0,0 +1,80 @@ +package auth + +import ( + "io/fs" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Read reads the toml file specified by the filepath given. +func Test_Read(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + fileContent string + settings Settings + errMessage string + }{ + "empty_file": {}, + "malformed_toml": { + fileContent: "this is not a toml file", + errMessage: `toml decoding file: toml: expected character =`, + }, + "unknown_field": { + fileContent: `unknown = "what is this"`, + errMessage: `toml decoding file: strict mode: fields in the document are missing in the target struct: +1| unknown = "what is this" + | ~~~~~~~ missing field`, + }, + "filled_settings": { + fileContent: `[[roles]] +name = "public" +auth = "none" +routes = ["GET /v1/vpn/status", "PUT /v1/vpn/status"] + +[[roles]] +name = "client" +auth = "apikey" +apikey = "xyz" +routes = ["GET /v1/vpn/status"] +`, + settings: Settings{ + Roles: []Role{{ + Name: "public", + Auth: AuthNone, + Routes: []string{"GET /v1/vpn/status", "PUT /v1/vpn/status"}, + }, { + Name: "client", + Auth: AuthAPIKey, + APIKey: "xyz", + Routes: []string{"GET /v1/vpn/status"}, + }}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + filepath := tempDir + "/config.toml" + const permissions fs.FileMode = 0600 + err := os.WriteFile(filepath, []byte(testCase.fileContent), permissions) + require.NoError(t, err) + + settings, err := Read(filepath) + + assert.Equal(t, testCase.settings, settings) + if testCase.errMessage != "" { + assert.EqualError(t, err, testCase.errMessage) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/server/middlewares/auth/format.go b/internal/server/middlewares/auth/format.go new file mode 100644 index 000000000..26d858e4f --- /dev/null +++ b/internal/server/middlewares/auth/format.go @@ -0,0 +1,22 @@ +package auth + +func andStrings(strings []string) (result string) { + return joinStrings(strings, "and") +} + +func joinStrings(strings []string, lastJoin string) (result string) { + if len(strings) == 0 { + return "" + } + + result = strings[0] + for i := 1; i < len(strings); i++ { + if i < len(strings)-1 { + result += ", " + strings[i] + } else { + result += " " + lastJoin + " " + strings[i] + } + } + + return result +} diff --git a/internal/server/middlewares/auth/interfaces.go b/internal/server/middlewares/auth/interfaces.go new file mode 100644 index 000000000..7a6901d23 --- /dev/null +++ b/internal/server/middlewares/auth/interfaces.go @@ -0,0 +1,6 @@ +package auth + +type DebugLogger interface { + Debugf(format string, args ...any) + Warnf(format string, args ...any) +} diff --git a/internal/server/middlewares/auth/interfaces_local.go b/internal/server/middlewares/auth/interfaces_local.go new file mode 100644 index 000000000..31a0aeed4 --- /dev/null +++ b/internal/server/middlewares/auth/interfaces_local.go @@ -0,0 +1,8 @@ +package auth + +import "net/http" + +type authorizationChecker interface { + equal(other authorizationChecker) bool + isAuthorized(headers http.Header, request *http.Request) bool +} diff --git a/internal/server/middlewares/auth/lookup.go b/internal/server/middlewares/auth/lookup.go new file mode 100644 index 000000000..d02c433b3 --- /dev/null +++ b/internal/server/middlewares/auth/lookup.go @@ -0,0 +1,47 @@ +package auth + +import ( + "fmt" +) + +type internalRole struct { + name string + checker authorizationChecker +} + +func settingsToLookupMap(settings Settings) (routeToRoles map[string][]internalRole, err error) { + routeToRoles = make(map[string][]internalRole) + for _, role := range settings.Roles { + var checker authorizationChecker + switch role.Auth { + case AuthNone: + checker = newNoneMethod() + case AuthAPIKey: + checker = newAPIKeyMethod(role.APIKey) + case AuthBasic: + checker = newBasicAuthMethod(role.Username, role.Password) + default: + return nil, fmt.Errorf("%w: %s", ErrMethodNotSupported, role.Auth) + } + + iRole := internalRole{ + name: role.Name, + checker: checker, + } + for _, route := range role.Routes { + checkerExists := false + for _, role := range routeToRoles[route] { + if role.checker.equal(iRole.checker) { + checkerExists = true + break + } + } + if checkerExists { + // even if the role name is different, if the checker is the same, skip it. + continue + } + routeToRoles[route] = append(routeToRoles[route], iRole) + } + } + return routeToRoles, nil +} diff --git a/internal/server/middlewares/auth/lookup_test.go b/internal/server/middlewares/auth/lookup_test.go new file mode 100644 index 000000000..225f7f8ce --- /dev/null +++ b/internal/server/middlewares/auth/lookup_test.go @@ -0,0 +1,60 @@ +package auth + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Read reads the toml file specified by the filepath given. +func Test_settingsToLookupMap(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + routeToRoles map[string][]internalRole + errWrapped error + errMessage string + }{ + "empty_settings": { + routeToRoles: map[string][]internalRole{}, + }, + "auth_method_not_supported": { + settings: Settings{ + Roles: []Role{{Name: "a", Auth: "bad"}}, + }, + errWrapped: ErrMethodNotSupported, + errMessage: "authentication method not supported: bad", + }, + "success": { + settings: Settings{ + Roles: []Role{ + {Name: "a", Auth: AuthNone, Routes: []string{"GET /path"}}, + {Name: "b", Auth: AuthNone, Routes: []string{"GET /path", "PUT /path"}}, + }, + }, + routeToRoles: map[string][]internalRole{ + "GET /path": { + {name: "a", checker: newNoneMethod()}, // deduplicated method + }, + "PUT /path": { + {name: "b", checker: newNoneMethod()}, + }}, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + routeToRoles, err := settingsToLookupMap(testCase.settings) + + assert.Equal(t, testCase.routeToRoles, routeToRoles) + assert.ErrorIs(t, err, testCase.errWrapped) + if testCase.errWrapped != nil { + assert.EqualError(t, err, testCase.errMessage) + } + }) + } +} diff --git a/internal/server/middlewares/auth/middleware.go b/internal/server/middlewares/auth/middleware.go new file mode 100644 index 000000000..7a6c18bcd --- /dev/null +++ b/internal/server/middlewares/auth/middleware.go @@ -0,0 +1,111 @@ +package auth + +import ( + "fmt" + "net/http" +) + +func New(settings Settings, debugLogger DebugLogger) ( + middleware func(http.Handler) http.Handler, + err error) { + routeToRoles, err := settingsToLookupMap(settings) + if err != nil { + return nil, fmt.Errorf("converting settings to lookup maps: %w", err) + } + + //nolint:goconst + return func(handler http.Handler) http.Handler { + return &authHandler{ + childHandler: handler, + routeToRoles: routeToRoles, + unprotectedRoutes: map[string]struct{}{ + http.MethodGet + " /openvpn/actions/restart": {}, + http.MethodGet + " /unbound/actions/restart": {}, + http.MethodGet + " /updater/restart": {}, + http.MethodGet + " /v1/version": {}, + http.MethodGet + " /v1/vpn/status": {}, + http.MethodPut + " /v1/vpn/status": {}, + // GET /v1/vpn/settings is protected by default + // PUT /v1/vpn/settings is protected by default + http.MethodGet + " /v1/openvpn/status": {}, + http.MethodPut + " /v1/openvpn/status": {}, + http.MethodGet + " /v1/openvpn/portforwarded": {}, + // GET /v1/openvpn/settings is protected by default + http.MethodGet + " /v1/dns/status": {}, + http.MethodPut + " /v1/dns/status": {}, + http.MethodGet + " /v1/updater/status": {}, + http.MethodPut + " /v1/updater/status": {}, + http.MethodGet + " /v1/publicip/ip": {}, + }, + logger: debugLogger, + } + }, nil +} + +type authHandler struct { + childHandler http.Handler + routeToRoles map[string][]internalRole + unprotectedRoutes map[string]struct{} // TODO v3.41.0 remove + logger DebugLogger +} + +func (h *authHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + route := request.Method + " " + request.URL.Path + roles := h.routeToRoles[route] + if len(roles) == 0 { + h.logger.Debugf("no authentication role defined for route %s", route) + http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + responseHeader := make(http.Header, 0) + for _, role := range roles { + if !role.checker.isAuthorized(responseHeader, request) { + continue + } + + h.warnIfUnprotectedByDefault(role, route) // TODO v3.41.0 remove + + h.logger.Debugf("access to route %s authorized for role %s", route, role.name) + h.childHandler.ServeHTTP(writer, request) + return + } + + // Flush out response headers if all roles failed to authenticate + for headerKey, headerValues := range responseHeader { + for _, headerValue := range headerValues { + writer.Header().Add(headerKey, headerValue) + } + } + + allRoleNames := make([]string, len(roles)) + for i, role := range roles { + allRoleNames[i] = role.name + } + h.logger.Debugf("access to route %s unauthorized after checking for roles %s", + route, andStrings(allRoleNames)) + http.Error(writer, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) +} + +func (h *authHandler) warnIfUnprotectedByDefault(role internalRole, route string) { + // TODO v3.41.0 remove + if role.name != "public" { + // custom role name, allow none authentication to be specified + return + } + _, isNoneChecker := role.checker.(*noneMethod) + if !isNoneChecker { + // not the none authentication method + return + } + _, isUnprotectedByDefault := h.unprotectedRoutes[route] + if !isUnprotectedByDefault { + // route is not unprotected by default, so this is a user decision + return + } + h.logger.Warnf("route %s is unprotected by default, "+ + "please set up authentication following the documentation at "+ + "https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+ + "since this will become no longer publicly accessible after release v3.40.", + route) +} diff --git a/internal/server/middlewares/auth/middleware_test.go b/internal/server/middlewares/auth/middleware_test.go new file mode 100644 index 000000000..5f9f75ccf --- /dev/null +++ b/internal/server/middlewares/auth/middleware_test.go @@ -0,0 +1,124 @@ +package auth + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_authHandler_ServeHTTP(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + settings Settings + makeLogger func(ctrl *gomock.Controller) *MockDebugLogger + requestMethod string + requestPath string + statusCode int + responseBody string + }{ + "route_has_no_role": { + settings: Settings{ + Roles: []Role{ + {Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}}, + }, + }, + makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { + logger := NewMockDebugLogger(ctrl) + logger.EXPECT().Debugf("no authentication role defined for route %s", "GET /b") + return logger + }, + requestMethod: http.MethodGet, + requestPath: "/b", + statusCode: http.StatusUnauthorized, + responseBody: "Unauthorized\n", + }, + "authorized_unprotected_by_default": { + settings: Settings{ + Roles: []Role{ + {Name: "public", Auth: AuthNone, Routes: []string{"GET /v1/vpn/status"}}, + }, + }, + makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { + logger := NewMockDebugLogger(ctrl) + logger.EXPECT().Warnf("route %s is unprotected by default, "+ + "please set up authentication following the documentation at "+ + "https://github.com/qdm12/gluetun-wiki/setup/advanced/control-server.md#authentication "+ + "since this will become no longer publicly accessible after release v3.40.", + "GET /v1/vpn/status") + logger.EXPECT().Debugf("access to route %s authorized for role %s", + "GET /v1/vpn/status", "public") + return logger + }, + requestMethod: http.MethodGet, + requestPath: "/v1/vpn/status", + statusCode: http.StatusOK, + }, + "authorized_none": { + settings: Settings{ + Roles: []Role{ + {Name: "role1", Auth: AuthNone, Routes: []string{"GET /a"}}, + }, + }, + makeLogger: func(ctrl *gomock.Controller) *MockDebugLogger { + logger := NewMockDebugLogger(ctrl) + logger.EXPECT().Debugf("access to route %s authorized for role %s", + "GET /a", "role1") + return logger + }, + requestMethod: http.MethodGet, + requestPath: "/a", + statusCode: http.StatusOK, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + var debugLogger DebugLogger + if testCase.makeLogger != nil { + debugLogger = testCase.makeLogger(ctrl) + } + middleware, err := New(testCase.settings, debugLogger) + require.NoError(t, err) + + childHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := middleware(childHandler) + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + client := server.Client() + + requestURL, err := url.JoinPath(server.URL, testCase.requestPath) + require.NoError(t, err) + request, err := http.NewRequestWithContext(context.Background(), + testCase.requestMethod, requestURL, nil) + require.NoError(t, err) + + response, err := client.Do(request) + require.NoError(t, err) + t.Cleanup(func() { + err = response.Body.Close() + assert.NoError(t, err) + }) + + assert.Equal(t, testCase.statusCode, response.StatusCode) + body, err := io.ReadAll(response.Body) + require.NoError(t, err) + assert.Equal(t, testCase.responseBody, string(body)) + }) + } +} diff --git a/internal/server/middlewares/auth/mocks_generate_test.go b/internal/server/middlewares/auth/mocks_generate_test.go new file mode 100644 index 000000000..d9ce4b052 --- /dev/null +++ b/internal/server/middlewares/auth/mocks_generate_test.go @@ -0,0 +1,3 @@ +package auth + +//go:generate mockgen -destination=mocks_test.go -package=$GOPACKAGE . DebugLogger diff --git a/internal/server/middlewares/auth/mocks_test.go b/internal/server/middlewares/auth/mocks_test.go new file mode 100644 index 000000000..37538c5bf --- /dev/null +++ b/internal/server/middlewares/auth/mocks_test.go @@ -0,0 +1,68 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/gluetun/internal/server/middlewares/auth (interfaces: DebugLogger) + +// Package auth is a generated GoMock package. +package auth + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockDebugLogger is a mock of DebugLogger interface. +type MockDebugLogger struct { + ctrl *gomock.Controller + recorder *MockDebugLoggerMockRecorder +} + +// MockDebugLoggerMockRecorder is the mock recorder for MockDebugLogger. +type MockDebugLoggerMockRecorder struct { + mock *MockDebugLogger +} + +// NewMockDebugLogger creates a new mock instance. +func NewMockDebugLogger(ctrl *gomock.Controller) *MockDebugLogger { + mock := &MockDebugLogger{ctrl: ctrl} + mock.recorder = &MockDebugLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDebugLogger) EXPECT() *MockDebugLoggerMockRecorder { + return m.recorder +} + +// Debugf mocks base method. +func (m *MockDebugLogger) Debugf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debugf", varargs...) +} + +// Debugf indicates an expected call of Debugf. +func (mr *MockDebugLoggerMockRecorder) Debugf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debugf", reflect.TypeOf((*MockDebugLogger)(nil).Debugf), varargs...) +} + +// Warnf mocks base method. +func (m *MockDebugLogger) Warnf(arg0 string, arg1 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warnf", varargs...) +} + +// Warnf indicates an expected call of Warnf. +func (mr *MockDebugLoggerMockRecorder) Warnf(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warnf", reflect.TypeOf((*MockDebugLogger)(nil).Warnf), varargs...) +} diff --git a/internal/server/middlewares/auth/none.go b/internal/server/middlewares/auth/none.go new file mode 100644 index 000000000..4d9f82018 --- /dev/null +++ b/internal/server/middlewares/auth/none.go @@ -0,0 +1,20 @@ +package auth + +import "net/http" + +type noneMethod struct{} + +func newNoneMethod() *noneMethod { + return &noneMethod{} +} + +// equal returns true if another auth checker is equal. +// This is used to deduplicate checkers for a particular route. +func (n *noneMethod) equal(other authorizationChecker) bool { + _, ok := other.(*noneMethod) + return ok +} + +func (n *noneMethod) isAuthorized(_ http.Header, _ *http.Request) bool { + return true +} diff --git a/internal/server/middlewares/auth/settings.go b/internal/server/middlewares/auth/settings.go new file mode 100644 index 000000000..70bc5e09e --- /dev/null +++ b/internal/server/middlewares/auth/settings.go @@ -0,0 +1,131 @@ +package auth + +import ( + "errors" + "fmt" + "net/http" + + "github.com/qdm12/gosettings" + "github.com/qdm12/gosettings/validate" +) + +type Settings struct { + // Roles is a list of roles with their associated authentication + // and routes. + Roles []Role +} + +func (s *Settings) SetDefaults() { + s.Roles = gosettings.DefaultSlice(s.Roles, []Role{{ // TODO v3.41.0 leave empty + Name: "public", + Auth: "none", + Routes: []string{ + http.MethodGet + " /openvpn/actions/restart", + http.MethodGet + " /unbound/actions/restart", + http.MethodGet + " /updater/restart", + http.MethodGet + " /v1/version", + http.MethodGet + " /v1/vpn/status", + http.MethodPut + " /v1/vpn/status", + http.MethodGet + " /v1/openvpn/status", + http.MethodPut + " /v1/openvpn/status", + http.MethodGet + " /v1/openvpn/portforwarded", + http.MethodGet + " /v1/dns/status", + http.MethodPut + " /v1/dns/status", + http.MethodGet + " /v1/updater/status", + http.MethodPut + " /v1/updater/status", + http.MethodGet + " /v1/publicip/ip", + }, + }}) +} + +func (s Settings) Validate() (err error) { + for i, role := range s.Roles { + err = role.validate() + if err != nil { + return fmt.Errorf("role %s (%d of %d): %w", + role.Name, i+1, len(s.Roles), err) + } + } + + return nil +} + +const ( + AuthNone = "none" + AuthAPIKey = "apikey" + AuthBasic = "basic" +) + +// Role contains the role name, authentication method name and +// routes that the role can access. +type Role struct { + // Name is the role name and is only used for documentation + // and in the authentication middleware debug logs. + Name string + // Auth is the authentication method to use, which can be 'none' or 'apikey'. + Auth string + // APIKey is the API key to use when using the 'apikey' authentication. + APIKey string + // Username for HTTP Basic authentication method. + Username string + // Password for HTTP Basic authentication method. + Password string + // Routes is a list of routes that the role can access in the format + // "HTTP_METHOD PATH", for example "GET /v1/vpn/status" + Routes []string +} + +var ( + ErrMethodNotSupported = errors.New("authentication method not supported") + ErrAPIKeyEmpty = errors.New("api key is empty") + ErrBasicUsernameEmpty = errors.New("username is empty") + ErrBasicPasswordEmpty = errors.New("password is empty") + ErrRouteNotSupported = errors.New("route not supported by the control server") +) + +func (r Role) validate() (err error) { + err = validate.IsOneOf(r.Auth, AuthNone, AuthAPIKey, AuthBasic) + if err != nil { + return fmt.Errorf("%w: %s", ErrMethodNotSupported, r.Auth) + } + + switch { + case r.Auth == AuthAPIKey && r.APIKey == "": + return fmt.Errorf("for role %s: %w", r.Name, ErrAPIKeyEmpty) + case r.Auth == AuthBasic && r.Username == "": + return fmt.Errorf("for role %s: %w", r.Name, ErrBasicUsernameEmpty) + case r.Auth == AuthBasic && r.Password == "": + return fmt.Errorf("for role %s: %w", r.Name, ErrBasicPasswordEmpty) + } + + for i, route := range r.Routes { + _, ok := validRoutes[route] + if !ok { + return fmt.Errorf("route %d of %d: %w: %s", + i+1, len(r.Routes), ErrRouteNotSupported, route) + } + } + + return nil +} + +// WARNING: do not mutate programmatically. +var validRoutes = map[string]struct{}{ //nolint:gochecknoglobals + http.MethodGet + " /openvpn/actions/restart": {}, + http.MethodGet + " /unbound/actions/restart": {}, + http.MethodGet + " /updater/restart": {}, + http.MethodGet + " /v1/version": {}, + http.MethodGet + " /v1/vpn/status": {}, + http.MethodPut + " /v1/vpn/status": {}, + http.MethodGet + " /v1/vpn/settings": {}, + http.MethodPut + " /v1/vpn/settings": {}, + http.MethodGet + " /v1/openvpn/status": {}, + http.MethodPut + " /v1/openvpn/status": {}, + http.MethodGet + " /v1/openvpn/portforwarded": {}, + http.MethodGet + " /v1/openvpn/settings": {}, + http.MethodGet + " /v1/dns/status": {}, + http.MethodPut + " /v1/dns/status": {}, + http.MethodGet + " /v1/updater/status": {}, + http.MethodPut + " /v1/updater/status": {}, + http.MethodGet + " /v1/publicip/ip": {}, +} diff --git a/internal/server/middlewares/log/interfaces.go b/internal/server/middlewares/log/interfaces.go new file mode 100644 index 000000000..6a6c62dc6 --- /dev/null +++ b/internal/server/middlewares/log/interfaces.go @@ -0,0 +1,5 @@ +package log + +type Logger interface { + Info(message string) +} diff --git a/internal/server/log.go b/internal/server/middlewares/log/middleware.go similarity index 80% rename from internal/server/log.go rename to internal/server/middlewares/log/middleware.go index c70dce46b..adfa9196d 100644 --- a/internal/server/log.go +++ b/internal/server/middlewares/log/middleware.go @@ -1,4 +1,4 @@ -package server +package log import ( "net/http" @@ -7,18 +7,21 @@ import ( "time" ) -func withLogMiddleware(childHandler http.Handler, logger infoer, enabled bool) *logMiddleware { - return &logMiddleware{ - childHandler: childHandler, - logger: logger, - timeNow: time.Now, - enabled: enabled, +func New(logger Logger, enabled bool) ( + middleware func(http.Handler) http.Handler) { + return func(handler http.Handler) http.Handler { + return &logMiddleware{ + childHandler: handler, + logger: logger, + timeNow: time.Now, + enabled: enabled, + } } } type logMiddleware struct { childHandler http.Handler - logger infoer + logger Logger timeNow func() time.Time enabled bool enabledMu sync.RWMutex @@ -39,7 +42,7 @@ func (m *logMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.RemoteAddr + " in " + duration.String()) } -func (m *logMiddleware) setEnabled(enabled bool) { +func (m *logMiddleware) SetEnabled(enabled bool) { m.enabledMu.Lock() defer m.enabledMu.Unlock() m.enabled = enabled diff --git a/internal/server/server.go b/internal/server/server.go index 9f31a9465..190f7ce70 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,17 +6,31 @@ import ( "github.com/qdm12/gluetun/internal/httpserver" "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/server/middlewares/auth" ) func New(ctx context.Context, address string, logEnabled bool, logger Logger, - buildInfo models.BuildInformation, openvpnLooper VPNLooper, + authConfigPath string, buildInfo models.BuildInformation, openvpnLooper VPNLooper, pfGetter PortForwardedGetter, unboundLooper DNSLoop, updaterLooper UpdaterLooper, publicIPLooper PublicIPLoop, storage Storage, ipv6Supported bool) ( server *httpserver.Server, err error) { - handler := newHandler(ctx, logger, logEnabled, buildInfo, + authSettings, err := auth.Read(authConfigPath) + if err != nil { + return nil, fmt.Errorf("reading auth settings: %w", err) + } + authSettings.SetDefaults() + err = authSettings.Validate() + if err != nil { + return nil, fmt.Errorf("validating auth settings: %w", err) + } + + handler, err := newHandler(ctx, logger, logEnabled, authSettings, buildInfo, openvpnLooper, pfGetter, unboundLooper, updaterLooper, publicIPLooper, storage, ipv6Supported) + if err != nil { + return nil, fmt.Errorf("creating handler: %w", err) + } httpServerSettings := httpserver.Settings{ Address: address,