diff --git a/internal/home/auth.go b/internal/home/auth.go index 7ce55ff9a0f..3a194f8274e 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -62,7 +62,7 @@ func (s *session) deserialize(data []byte) bool { // Auth - global object type Auth struct { db *bbolt.DB - blocker *authBlocker + blocker *authRateLimiter sessions map[string]*session users []User lock sync.Mutex @@ -76,7 +76,7 @@ type User struct { } // InitAuth - create a global object -func InitAuth(dbFilename string, users []User, sessionTTL uint32, blocker *authBlocker) *Auth { +func InitAuth(dbFilename string, users []User, sessionTTL uint32, blocker *authRateLimiter) *Auth { log.Info("Initializing auth module: %s", dbFilename) a := &Auth{ @@ -333,19 +333,19 @@ func cookieExpiryFormat(exp time.Time) (formatted string) { return exp.Format(cookieTimeFormat) } -func (a *Auth) httpCookie(req loginJSON, ip net.IP) (cookie string, err error) { +func (a *Auth) httpCookie(req loginJSON, addr string) (cookie string, err error) { blocker := a.blocker u := a.UserFind(req.Name, req.Password) if len(u.Name) == 0 { if blocker != nil { - blocker.inc(ip) + blocker.inc(addr) } return "", err } if blocker != nil { - blocker.remove(ip) + blocker.remove(addr) } var sess []byte @@ -417,23 +417,31 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { err := json.NewDecoder(r.Body).Decode(&req) if err != nil { httpError(w, http.StatusBadRequest, "json decode: %s", err) + return } - var ip net.IP - ip, err = realIP(r) - if err != nil { - log.Info("auth: getting real ip from request: %s", err) + var remoteAddr string + // The realIP couldn't be used here due to security issues. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/2799. + // + // TODO(e.burkov): Use realIP when the issue will be fixed. + if remoteAddr, err = aghnet.SplitHost(r.RemoteAddr); err == nil { + httpError(w, http.StatusBadRequest, "auth: getting remote address: %s", err) + + return } if blocker := Context.auth.blocker; blocker != nil { - left := blocker.check(ip) + left := blocker.check(remoteAddr) if left > 0 { w.Header().Set("Retry-After", fmt.Sprint(int64(left.Seconds()))) httpError( w, http.StatusTooManyRequests, - "auth: out of login attempts", + "auth: blocked for %d minute(s)", + 0, ) return @@ -441,7 +449,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { } var cookie string - cookie, err = Context.auth.httpCookie(req, ip) + cookie, err = Context.auth.httpCookie(req, remoteAddr) if err != nil { httpError(w, http.StatusBadRequest, "crypto rand reader: %s", err) @@ -449,7 +457,11 @@ func handleLogin(w http.ResponseWriter, r *http.Request) { } if len(cookie) == 0 { - if ip == nil { + var ip net.IP + ip, err = realIP(r) + if err != nil { + log.Info("auth: getting real ip from request: %s", err) + } else if ip == nil { // Technically shouldn't happen. log.Info("auth: failed to login user %q from unknown ip", req.Name) } else { diff --git a/internal/home/auth_test.go b/internal/home/auth_test.go index bda36b31795..f6b2b10639f 100644 --- a/internal/home/auth_test.go +++ b/internal/home/auth_test.go @@ -151,7 +151,7 @@ func TestAuthHTTP(t *testing.T) { assert.True(t, handlerCalled) // perform login - cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, nil) + cookie, err := Context.auth.httpCookie(loginJSON{Name: "name", Password: "password"}, "") assert.Nil(t, err) assert.NotEmpty(t, cookie) diff --git a/internal/home/authblocker.go b/internal/home/authblocker.go deleted file mode 100644 index 8e6982a9f4b..00000000000 --- a/internal/home/authblocker.go +++ /dev/null @@ -1,138 +0,0 @@ -package home - -import ( - "net" - "sync" - "time" - - "github.com/AdguardTeam/AdGuardHome/internal/agherr" -) - -const ( - // flushAfter defines for how long will attempter be tracked. - flushAfter = 1 * time.Minute - // flusherPeriod determines the frequency of flusher loop work. - flusherPeriod = 1 * time.Second -) - -// attempter is an entry of authBlocker's cache. -type attempter struct { - until time.Time - num uint16 -} - -// authBlocker used to cache unsuccessful authentication attempts. -type authBlocker struct { - attempters map[string]attempter - blockDur time.Duration - attLock sync.RWMutex - maxAttempts uint16 -} - -// flushLocked performs all the dirty work. For internal use only. -func (ab *authBlocker) flushLocked(now time.Time) { - withDelta := now.Add(-flushAfter) - for k, v := range ab.attempters { - if withDelta.After(v.until) { - delete(ab.attempters, k) - } - } -} - -// flush stops tracking of attempters tracking period of which have ended. -func (ab *authBlocker) flush() { - now := time.Now() - - ab.attLock.Lock() - defer ab.attLock.Unlock() - - ab.flushLocked(now) -} - -// flusher is a loop with flush method for concurrent usage. -func (ab *authBlocker) flusher(tk <-chan time.Time) { - defer agherr.LogPanic("authblocker: flusher loop") - - for range tk { - ab.flush() - } -} - -// newAuthBlocker returns properly initialized *authBlocker. -func newAuthBlocker(blockDur time.Duration, maxAttempts uint16) (ab *authBlocker) { - ab = &authBlocker{ - attempters: make(map[string]attempter), - blockDur: blockDur, - maxAttempts: maxAttempts, - } - - tk := time.NewTicker(flusherPeriod) - go ab.flusher(tk.C) - - return ab -} - -// checkLocked performs all the dirty work. For internal use only. -func (ab *authBlocker) checkLocked(ip net.IP, now time.Time) (left time.Duration) { - a, ok := ab.attempters[string(ip)] - if !ok { - return 0 - } - - if a.num < ab.maxAttempts { - return 0 - } - - return a.until.Sub(now) -} - -// check returns the time left until unblocking. The nonpositive result should -// be interpreted as not blocked attempter. -func (ab *authBlocker) check(ip net.IP) (left time.Duration) { - now := time.Now() - - ab.attLock.RLock() - defer ab.attLock.RUnlock() - - return ab.checkLocked(ip, now) -} - -// incLocked performs all the dirty work. For internal use only. -func (ab *authBlocker) incLocked(ip net.IP, now time.Time) { - defer agherr.LogPanic("authblocker") - - var until time.Time = now - var attempts uint16 = 1 - - id := string(ip) - a, ok := ab.attempters[id] - if ok { - attempts = a.num + 1 - } - if attempts >= ab.maxAttempts { - until = until.Add(ab.blockDur) - } - - ab.attempters[id] = attempter{ - num: attempts, - until: until, - } -} - -// inc updates the tracked attempter. -func (ab *authBlocker) inc(ip net.IP) { - now := time.Now() - - ab.attLock.Lock() - defer ab.attLock.Unlock() - - ab.incLocked(ip, now) -} - -// remove stops any tracking and any blocking of an attempter. -func (ab *authBlocker) remove(ip net.IP) { - ab.attLock.Lock() - defer ab.attLock.Unlock() - - delete(ab.attempters, string(ip)) -} diff --git a/internal/home/authratelimiter.go b/internal/home/authratelimiter.go new file mode 100644 index 00000000000..0abf3af251b --- /dev/null +++ b/internal/home/authratelimiter.go @@ -0,0 +1,119 @@ +package home + +import ( + "sync" + "time" +) + +// flushAfter defines for how long will attempter be tracked. +const flushAfter = 1 * time.Minute + +// attempter is an entry of authRateLimiter's cache. +type attempter struct { + until time.Time + num uint16 +} + +// authRateLimiter used to cache unsuccessful authentication attempts. +type authRateLimiter struct { + attempters map[string]attempter + blockDur time.Duration + attemptersLock sync.RWMutex + maxAttempts uint16 +} + +// newAuthRateLimiter returns properly initialized *authRateLimiter. +func newAuthRateLimiter(blockDur time.Duration, maxAttempts uint16) (ab *authRateLimiter) { + return &authRateLimiter{ + attempters: make(map[string]attempter), + blockDur: blockDur, + maxAttempts: maxAttempts, + } +} + +// flushLocked checks each tracked attempter removing expired ones. For +// internal use only. +func (ab *authRateLimiter) flushLocked(now time.Time) { + for k, v := range ab.attempters { + if now.After(v.until) { + delete(ab.attempters, k) + } + } +} + +// flush stops tracking of attempters tracking period of which have ended. +func (ab *authRateLimiter) flush() { + now := time.Now() + + ab.attemptersLock.Lock() + defer ab.attemptersLock.Unlock() + + ab.flushLocked(now) +} + +// checkLocked checks the attempter for it's state. For internal use only. +func (ab *authRateLimiter) checkLocked(attID string, now time.Time) (left time.Duration) { + a, ok := ab.attempters[attID] + if !ok { + return 0 + } + + if a.num < ab.maxAttempts { + return 0 + } + + return a.until.Sub(now) +} + +// check returns the time left until unblocking. The nonpositive result should +// be interpreted as not blocked attempter. +func (ab *authRateLimiter) check(attID string) (left time.Duration) { + now := time.Now() + ab.flush() + + ab.attemptersLock.RLock() + defer ab.attemptersLock.RUnlock() + + return ab.checkLocked(attID, now) +} + +// incLocked increments the number of unsuccessful attempts for attempter with +// ip and updates it's blocking moment if needed. For internal use only. +func (ab *authRateLimiter) incLocked(attID string, now time.Time) { + var until time.Time = now.Add(flushAfter) + var attempts uint16 = 1 + + a, ok := ab.attempters[attID] + if ok { + // The attempter will be tracked during at least 1 minute since + // very first unsuccessful attempt but not since each one. + until = a.until + attempts = a.num + 1 + } + if attempts >= ab.maxAttempts { + until = now.Add(ab.blockDur) + } + + ab.attempters[attID] = attempter{ + num: attempts, + until: until, + } +} + +// inc updates the tracked attempter. +func (ab *authRateLimiter) inc(attID string) { + now := time.Now() + + ab.attemptersLock.Lock() + defer ab.attemptersLock.Unlock() + + ab.incLocked(attID, now) +} + +// remove stops any tracking and any blocking of an attempter. +func (ab *authRateLimiter) remove(attID string) { + ab.attemptersLock.Lock() + defer ab.attemptersLock.Unlock() + + delete(ab.attempters, attID) +} diff --git a/internal/home/authblocker_test.go b/internal/home/authratelimiter_test.go similarity index 82% rename from internal/home/authblocker_test.go rename to internal/home/authratelimiter_test.go index 5a400235c90..35d0aa2ff18 100644 --- a/internal/home/authblocker_test.go +++ b/internal/home/authratelimiter_test.go @@ -9,8 +9,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestAuthBlocker_Flush(t *testing.T) { - const key = "127.0.0.1" +func TestAuthRateLimiter_Flush(t *testing.T) { + const key = "some-key" testCases := []struct { name string @@ -25,7 +25,7 @@ func TestAuthBlocker_Flush(t *testing.T) { }, { name: "nope_yet", att: attempter{ - until: time.Now(), + until: time.Now().Add(flushAfter / 2), }, wantExp: false, }, { @@ -37,7 +37,7 @@ func TestAuthBlocker_Flush(t *testing.T) { }} for _, tc := range testCases { - ab := &authBlocker{ + ab := &authRateLimiter{ attempters: map[string]attempter{ key: tc.att, }, @@ -58,7 +58,7 @@ func TestAuthBlocker_Flush(t *testing.T) { } } -func TestAuthBlocker_Check(t *testing.T) { +func TestAuthRateLimiter_Check(t *testing.T) { key := string(net.IP{127, 0, 0, 1}) const maxAtt = 1 now := time.Now() @@ -73,6 +73,11 @@ func TestAuthBlocker_Check(t *testing.T) { name: "expired", num: 0, wantExp: true, + }, { + until: now.Add(flushAfter), + name: "not_blocked_but_tracked", + num: 0, + wantExp: true, }, { until: now, name: "expired_but_stayed", @@ -92,12 +97,12 @@ func TestAuthBlocker_Check(t *testing.T) { until: tc.until, }, } - ab := &authBlocker{ + ab := &authRateLimiter{ maxAttempts: maxAtt, attempters: attempters, } t.Run(tc.name, func(t *testing.T) { - until := ab.check(net.IP{127, 0, 0, 1}) + until := ab.check(key) if tc.wantExp { assert.LessOrEqual(t, until, time.Duration(0)) @@ -108,19 +113,19 @@ func TestAuthBlocker_Check(t *testing.T) { } t.Run("non-existent", func(t *testing.T) { - ab := &authBlocker{ + ab := &authRateLimiter{ attempters: map[string]attempter{ key + "smthng": {}, }, } - until := ab.check(net.IP{127, 0, 0, 1}) + until := ab.check(key) assert.Zero(t, until) }) } -func TestAuthBlocker_Inc(t *testing.T) { +func TestAuthRateLimiter_Inc(t *testing.T) { ip := net.IP{127, 0, 0, 1} key := string(ip) now := time.Now() @@ -154,13 +159,13 @@ func TestAuthBlocker_Inc(t *testing.T) { until: tc.until, }, } - ab := &authBlocker{ + ab := &authRateLimiter{ blockDur: blockDur, maxAttempts: maxAtt, attempters: attempters, } t.Run(tc.name, func(t *testing.T) { - ab.inc(ip) + ab.inc(key) a, ok := ab.attempters[key] require.True(t, ok) @@ -171,13 +176,13 @@ func TestAuthBlocker_Inc(t *testing.T) { } t.Run("non-existent", func(t *testing.T) { - ab := &authBlocker{ + ab := &authRateLimiter{ blockDur: blockDur, maxAttempts: maxAtt, attempters: map[string]attempter{}, } - ab.inc(ip) + ab.inc(key) a, ok := ab.attempters[key] require.True(t, ok) @@ -185,17 +190,17 @@ func TestAuthBlocker_Inc(t *testing.T) { }) } -func TestAuthBlocker_Remove(t *testing.T) { - ip := net.IP{127, 0, 0, 1} +func TestAuthRateLimiter_Remove(t *testing.T) { + const key = "some-key" attempters := map[string]attempter{ - string(ip): {}, + key: {}, } - ab := &authBlocker{ + ab := &authRateLimiter{ attempters: attempters, } - ab.remove(ip) + ab.remove(key) assert.Empty(t, ab.attempters) } diff --git a/internal/home/home.go b/internal/home/home.go index b83cbadfca6..3cbcba7ca0d 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -282,18 +282,20 @@ func run(args options) { sessFilename := filepath.Join(Context.getDataDir(), "sessions.db") GLMode = args.glinetMode - var authBlocker *authBlocker - if config.AuthAttempts > 0 { - authBlocker = newAuthBlocker( + var arl *authRateLimiter + if config.AuthAttempts > 0 && config.AuthBlockMin > 0 { + arl = newAuthRateLimiter( time.Duration(config.AuthBlockMin)*time.Minute, config.AuthAttempts, ) + } else { + log.Info("the authratelimiter is disabled") } Context.auth = InitAuth( sessFilename, config.Users, config.WebSessionTTLHours*60*60, - authBlocker, + arl, ) if Context.auth == nil { log.Fatalf("Couldn't initialize Auth module")