Skip to content

Commit

Permalink
Pull request: 2471 defer
Browse files Browse the repository at this point in the history
Merge in DNS/adguard-home from 2471-defer to master

Updates #2471.

* commit '1c754788f9139ed9741cf01c6d94bcced6909b8c':
  home: improve getCurrentUser
  home: improve checkSession
  Use a couple of defer in internal/home/auth.go
  • Loading branch information
ainar-g committed Dec 23, 2020
2 parents fc79e2e + 1c75478 commit e829e7a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 34 deletions.
64 changes: 38 additions & 26 deletions internal/home/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ func (s *session) deserialize(data []byte) bool {
// Auth - global object
type Auth struct {
db *bbolt.DB
sessions map[string]*session // session name -> session data
lock sync.Mutex
sessions map[string]*session
users []User
sessionTTL uint32 // in seconds
lock sync.Mutex
sessionTTL uint32
}

// User object
Expand Down Expand Up @@ -223,24 +223,35 @@ func (a *Auth) removeSession(sess []byte) {
log.Debug("Auth: removed session from DB")
}

// CheckSession - check if session is valid
// Return 0 if OK; -1 if session doesn't exist; 1 if session has expired
func (a *Auth) CheckSession(sess string) int {
// checkSessionResult is the result of checking a session.
type checkSessionResult int

// checkSessionResult constants.
const (
checkSessionOK checkSessionResult = 0
checkSessionNotFound checkSessionResult = -1
checkSessionExpired checkSessionResult = 1
)

// checkSession checks if the session is valid.
func (a *Auth) checkSession(sess string) (res checkSessionResult) {
now := uint32(time.Now().UTC().Unix())
update := false

a.lock.Lock()
defer a.lock.Unlock()

s, ok := a.sessions[sess]
if !ok {
a.lock.Unlock()
return -1
return checkSessionNotFound
}

if s.expire <= now {
delete(a.sessions, sess)
key, _ := hex.DecodeString(sess)
a.removeSession(key)
a.lock.Unlock()
return 1

return checkSessionExpired
}

newExpire := now + a.sessionTTL
Expand All @@ -250,16 +261,14 @@ func (a *Auth) CheckSession(sess string) int {
s.expire = newExpire
}

a.lock.Unlock()

if update {
key, _ := hex.DecodeString(sess)
if a.storeSession(key, s) {
log.Debug("Auth: updated session %s: expire=%d", sess, s.expire)
}
}

return 0
return checkSessionOK
}

// RemoveSession - remove session
Expand Down Expand Up @@ -392,8 +401,8 @@ func optionalAuthThird(w http.ResponseWriter, r *http.Request) (authFirst bool)
ok = true

} else if err == nil {
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
r := Context.auth.checkSession(cookie.Value)
if r == checkSessionOK {
ok = true
} else if r < 0 {
log.Debug("Auth: invalid cookie value: %s", cookie)
Expand Down Expand Up @@ -434,12 +443,13 @@ func optionalAuth(handler func(http.ResponseWriter, *http.Request)) func(http.Re
authRequired := Context.auth != nil && Context.auth.AuthRequired()
cookie, err := r.Cookie(sessionCookieName)
if authRequired && err == nil {
r := Context.auth.CheckSession(cookie.Value)
if r == 0 {
r := Context.auth.checkSession(cookie.Value)
if r == checkSessionOK {
w.Header().Set("Location", "/")
w.WriteHeader(http.StatusFound)

return
} else if r < 0 {
} else if r == checkSessionNotFound {
log.Debug("Auth: invalid cookie value: %s", cookie)
}
}
Expand Down Expand Up @@ -503,32 +513,34 @@ func (a *Auth) UserFind(login, password string) User {
return User{}
}

// GetCurrentUser - get the current user
func (a *Auth) GetCurrentUser(r *http.Request) User {
// getCurrentUser returns the current user. It returns an empty User if the
// user is not found.
func (a *Auth) getCurrentUser(r *http.Request) User {
cookie, err := r.Cookie(sessionCookieName)
if err != nil {
// there's no Cookie, check Basic authentication
// There's no Cookie, check Basic authentication.
user, pass, ok := r.BasicAuth()
if ok {
u := Context.auth.UserFind(user, pass)
return u
return Context.auth.UserFind(user, pass)
}

return User{}
}

a.lock.Lock()
defer a.lock.Unlock()

s, ok := a.sessions[cookie.Value]
if !ok {
a.lock.Unlock()
return User{}
}

for _, u := range a.users {
if u.Name == s.userName {
a.lock.Unlock()
return u
}
}
a.lock.Unlock()

return User{}
}

Expand Down
14 changes: 7 additions & 7 deletions internal/home/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestAuth(t *testing.T) {
user := User{Name: "name"}
a.UserAdd(&user, "password")

assert.True(t, a.CheckSession("notfound") == -1)
assert.Equal(t, checkSessionNotFound, a.checkSession("notfound"))
a.RemoveSession("notfound")

sess, err := getSession(&users[0])
Expand All @@ -49,22 +49,22 @@ func TestAuth(t *testing.T) {
// check expiration
s.expire = uint32(now)
a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 1)
assert.Equal(t, checkSessionExpired, a.checkSession(sessStr))

// add session with TTL = 2 sec
s = session{}
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.addSession(sess, &s)
assert.True(t, a.CheckSession(sessStr) == 0)
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))

a.Close()

// load saved session
a = InitAuth(fn, users, 60)

// the session is still alive
assert.True(t, a.CheckSession(sessStr) == 0)
// reset our expiration time because CheckSession() has just updated it
assert.Equal(t, checkSessionOK, a.checkSession(sessStr))
// reset our expiration time because checkSession() has just updated it
s.expire = uint32(time.Now().UTC().Unix() + 2)
a.storeSession(sess, &s)
a.Close()
Expand All @@ -76,7 +76,7 @@ func TestAuth(t *testing.T) {

// load and remove expired sessions
a = InitAuth(fn, users, 60)
assert.True(t, a.CheckSession(sessStr) == -1)
assert.Equal(t, checkSessionNotFound, a.checkSession(sessStr))

a.Close()
os.Remove(fn)
Expand Down Expand Up @@ -111,7 +111,7 @@ func TestAuthHTTP(t *testing.T) {
Context.auth = InitAuth(fn, users, 60)

handlerCalled := false
handler := func(w http.ResponseWriter, r *http.Request) {
handler := func(_ http.ResponseWriter, _ *http.Request) {
handlerCalled = true
}
handler2 := optionalAuth(handler)
Expand Down
2 changes: 1 addition & 1 deletion internal/home/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ type profileJSON struct {

func handleGetProfile(w http.ResponseWriter, r *http.Request) {
pj := profileJSON{}
u := Context.auth.GetCurrentUser(r)
u := Context.auth.getCurrentUser(r)
pj.Name = u.Name

data, err := json.Marshal(pj)
Expand Down

0 comments on commit e829e7a

Please sign in to comment.