Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(config): allow invalid server filters #2419

Merged
merged 1 commit into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/gluetun/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
return fmt.Errorf("checking for IPv6 support: %w", err)
}

err = allSettings.Validate(storage, ipv6Supported)
err = allSettings.Validate(storage, ipv6Supported, logger)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/cli/openvpnconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (c *CLI) OpenvpnConfig(logger OpenvpnConfigLogger, reader *reader.Reader,
return fmt.Errorf("checking for IPv6 support: %w", err)
}

if err = allSettings.Validate(storage, ipv6Supported); err != nil {
if err = allSettings.Validate(storage, ipv6Supported, logger); err != nil {
return fmt.Errorf("validating settings: %w", err)
}

Expand Down
5 changes: 5 additions & 0 deletions internal/configuration/settings/interfaces.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package settings

type Warner interface {
Warn(message string)
}
4 changes: 2 additions & 2 deletions internal/configuration/settings/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type Provider struct {
}

// TODO v4 remove pointer for receiver (because of Surfshark).
func (p *Provider) validate(vpnType string, storage Storage) (err error) {
func (p *Provider) validate(vpnType string, storage Storage, warner Warner) (err error) {
// Validate Name
var validNames []string
if vpnType == vpn.OpenVPN {
Expand All @@ -48,7 +48,7 @@ func (p *Provider) validate(vpnType string, storage Storage) (err error) {
return fmt.Errorf("%w for Wireguard: %w", ErrVPNProviderNameNotValid, err)
}

err = p.ServerSelection.validate(p.Name, storage)
err = p.ServerSelection.validate(p.Name, storage, warner)
if err != nil {
return fmt.Errorf("server selection: %w", err)
}
Expand Down
66 changes: 51 additions & 15 deletions internal/configuration/settings/serverselection.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,14 @@ var (
)

func (ss *ServerSelection) validate(vpnServiceProvider string,
storage Storage) (err error) {
storage Storage, warner Warner) (err error) {
switch ss.VPN {
case vpn.OpenVPN, vpn.Wireguard:
default:
return fmt.Errorf("%w: %s", ErrVPNTypeNotValid, ss.VPN)
}

filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage)
filterChoices, err := getLocationFilterChoices(vpnServiceProvider, ss, storage, warner)
if err != nil {
return err // already wrapped error
}
Expand All @@ -111,7 +111,7 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
*ss = surfsharkRetroRegion(*ss)
}

err = validateServerFilters(*ss, filterChoices, vpnServiceProvider)
err = validateServerFilters(*ss, filterChoices, vpnServiceProvider, warner)
if err != nil {
return fmt.Errorf("for VPN service provider %s: %w", vpnServiceProvider, err)
}
Expand Down Expand Up @@ -142,19 +142,19 @@ func (ss *ServerSelection) validate(vpnServiceProvider string,
}

func getLocationFilterChoices(vpnServiceProvider string,
ss *ServerSelection, storage Storage) (filterChoices models.FilterChoices,
err error) {
ss *ServerSelection, storage Storage, warner Warner) (
filterChoices models.FilterChoices, err error) {
filterChoices = storage.GetFilterChoices(vpnServiceProvider)

if vpnServiceProvider == providers.Surfshark {
// // Retro compatibility
// TODO v4 remove
newAndRetroRegions := append(filterChoices.Regions, validation.SurfsharkRetroLocChoices()...) //nolint:gocritic
err := validate.AreAllOneOfCaseInsensitive(ss.Regions, newAndRetroRegions)
err := atLeastOneIsOneOfCaseInsensitive(ss.Regions, newAndRetroRegions, warner)
if err != nil {
// Only return error comparing with newer regions, we don't want to confuse the user
// with the retro regions in the error message.
err = validate.AreAllOneOfCaseInsensitive(ss.Regions, filterChoices.Regions)
err = atLeastOneIsOneOfCaseInsensitive(ss.Regions, filterChoices.Regions, warner)
return models.FilterChoices{}, fmt.Errorf("%w: %w", ErrRegionNotValid, err)
}
}
Expand All @@ -165,28 +165,28 @@ func getLocationFilterChoices(vpnServiceProvider string,
// validateServerFilters validates filters against the choices given as arguments.
// Set an argument to nil to pass the check for a particular filter.
func validateServerFilters(settings ServerSelection, filterChoices models.FilterChoices,
vpnServiceProvider string) (err error) {
err = validate.AreAllOneOfCaseInsensitive(settings.Countries, filterChoices.Countries)
vpnServiceProvider string, warner Warner) (err error) {
err = atLeastOneIsOneOfCaseInsensitive(settings.Countries, filterChoices.Countries, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCountryNotValid, err)
}

err = validate.AreAllOneOfCaseInsensitive(settings.Regions, filterChoices.Regions)
err = atLeastOneIsOneOfCaseInsensitive(settings.Regions, filterChoices.Regions, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrRegionNotValid, err)
}

err = validate.AreAllOneOfCaseInsensitive(settings.Cities, filterChoices.Cities)
err = atLeastOneIsOneOfCaseInsensitive(settings.Cities, filterChoices.Cities, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCityNotValid, err)
}

err = validate.AreAllOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs)
err = atLeastOneIsOneOfCaseInsensitive(settings.ISPs, filterChoices.ISPs, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrISPNotValid, err)
}

err = validate.AreAllOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames)
err = atLeastOneIsOneOfCaseInsensitive(settings.Hostnames, filterChoices.Hostnames, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrHostnameNotValid, err)
}
Expand All @@ -197,19 +197,55 @@ func validateServerFilters(settings ServerSelection, filterChoices models.Filter
// which requires a server name for TLS verification.
filterChoices.Names = settings.Names
}
err = validate.AreAllOneOfCaseInsensitive(settings.Names, filterChoices.Names)
err = atLeastOneIsOneOfCaseInsensitive(settings.Names, filterChoices.Names, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrNameNotValid, err)
}

err = validate.AreAllOneOfCaseInsensitive(settings.Categories, filterChoices.Categories)
err = atLeastOneIsOneOfCaseInsensitive(settings.Categories, filterChoices.Categories, warner)
if err != nil {
return fmt.Errorf("%w: %w", ErrCategoryNotValid, err)
}

return nil
}

func atLeastOneIsOneOfCaseInsensitive(values, choices []string,
warner Warner) (err error) {
if len(values) > 0 && len(choices) == 0 {
return fmt.Errorf("%w", validate.ErrNoChoice)
}

set := make(map[string]struct{}, len(choices))
for _, choice := range choices {
lowercaseChoice := strings.ToLower(choice)
set[lowercaseChoice] = struct{}{}
}

invalidValues := make([]string, 0, len(values))
for _, value := range values {
lowercaseValue := strings.ToLower(value)
_, ok := set[lowercaseValue]
if ok {
continue
}
invalidValues = append(invalidValues, value)
}

switch len(invalidValues) {
case 0:
return nil
case len(values):
return fmt.Errorf("%w: none of %s is one of the choices available %s",
validate.ErrValueNotOneOf, strings.Join(values, ", "), strings.Join(choices, ", "))
default:
warner.Warn(fmt.Sprintf("values %s are not in choices %s",
strings.Join(invalidValues, ", "), strings.Join(choices, ", ")))
}

return nil
}

func validateSubscriptionTierFilters(settings ServerSelection, vpnServiceProvider string) error {
switch {
case *settings.FreeOnly &&
Expand Down
9 changes: 5 additions & 4 deletions internal/configuration/settings/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ type Storage interface {
// Validate validates all the settings and returns an error
// if one of them is not valid.
// TODO v4 remove pointer for receiver (because of Surfshark).
func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) {
func (s *Settings) Validate(storage Storage, ipv6Supported bool,
warner Warner) (err error) {
nameToValidation := map[string]func() error{
"control server": s.ControlServer.validate,
"dns": s.DNS.validate,
Expand All @@ -51,7 +52,7 @@ func (s *Settings) Validate(storage Storage, ipv6Supported bool) (err error) {
"version": s.Version.validate,
// Pprof validation done in pprof constructor
"VPN": func() error {
return s.VPN.Validate(storage, ipv6Supported)
return s.VPN.Validate(storage, ipv6Supported, warner)
},
}

Expand Down Expand Up @@ -84,7 +85,7 @@ func (s *Settings) copy() (copied Settings) {
}

func (s *Settings) OverrideWith(other Settings,
storage Storage, ipv6Supported bool) (err error) {
storage Storage, ipv6Supported bool, warner Warner) (err error) {
patchedSettings := s.copy()
patchedSettings.ControlServer.overrideWith(other.ControlServer)
patchedSettings.DNS.overrideWith(other.DNS)
Expand All @@ -99,7 +100,7 @@ func (s *Settings) OverrideWith(other Settings,
patchedSettings.Version.overrideWith(other.Version)
patchedSettings.VPN.OverrideWith(other.VPN)
patchedSettings.Pprof.OverrideWith(other.Pprof)
err = patchedSettings.Validate(storage, ipv6Supported)
err = patchedSettings.Validate(storage, ipv6Supported, warner)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions internal/configuration/settings/vpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ type VPN struct {
}

// TODO v4 remove pointer for receiver (because of Surfshark).
func (v *VPN) Validate(storage Storage, ipv6Supported bool) (err error) {
func (v *VPN) Validate(storage Storage, ipv6Supported bool, warner Warner) (err error) {
// Validate Type
validVPNTypes := []string{vpn.OpenVPN, vpn.Wireguard}
if err = validate.IsOneOf(v.Type, validVPNTypes...); err != nil {
return fmt.Errorf("%w: %w", ErrVPNTypeNotValid, err)
}

err = v.Provider.validate(v.Type, storage)
err = v.Provider.validate(v.Type, storage, warner)
if err != nil {
return fmt.Errorf("provider settings: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/vpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (h *vpnHandler) patchSettings(w http.ResponseWriter, r *http.Request) {

updatedSettings := h.looper.GetSettings() // already copied
updatedSettings.OverrideWith(overrideSettings)
err = updatedSettings.Validate(h.storage, h.ipv6Supported)
err = updatedSettings.Validate(h.storage, h.ipv6Supported, h.warner)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down
Loading