Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: concurrent calls not authorized and LDAP timeout #20

Merged
merged 8 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ build:
go build -ldflags="-X 'github.com/criteo/data-aggregation-api/internal/version.version=$(VER)' \
-X 'github.com/criteo/data-aggregation-api/internal/version.buildUser=$(USER)' \
-X 'github.com/criteo/data-aggregation-api/internal/version.buildTime=$(BUILDDATE)'" \
-o .build/data_aggregation_api ./cmd/data_aggregation_api
-o .build/data-aggregation-api ./cmd/data-aggregation-api

run:
go run -ldflags="-X 'github.com/criteo/data-aggregation-api/internal/version.version=$(VER)' \
-X 'github.com/criteo/data-aggregation-api/internal/version.buildUser=$(USER)' \
-X 'github.com/criteo/data-aggregation-api/internal/version.buildTime=$(BUILDDATE)'" \
./cmd/data_aggregation_api
./cmd/data-aggregation-api

update_openconfig:
./update_openconfig.sh
9 changes: 9 additions & 0 deletions cmd/data-aggregation-api/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"

"github.com/criteo/data-aggregation-api/internal/api/auth"
"github.com/criteo/data-aggregation-api/internal/api/router"
"github.com/criteo/data-aggregation-api/internal/app"
"github.com/criteo/data-aggregation-api/internal/config"
Expand Down Expand Up @@ -55,6 +56,14 @@ func run() error {
log.Info().Str("build-time", date).Send()
log.Info().Str("build-user", builtBy).Send()

// Configure LDAP timeout
if config.Cfg.Authentication.LDAP != nil {
if config.Cfg.Authentication.LDAP.Timeout <= 0 {
return fmt.Errorf("LDAP timeout must be greater than 0: %d", config.Cfg.Authentication.LDAP.Timeout)
}
auth.SetLDAPDefaultTimeout(config.Cfg.Authentication.LDAP.Timeout)
}

deviceRepo := device.NewSafeRepository()
reports := report.NewRepository()

Expand Down
19 changes: 18 additions & 1 deletion internal/api/auth/basic_auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -29,7 +30,7 @@ type BasicAuth struct {
mode authMode
}

func NewBasicAuth(cfg config.AuthConfig) (BasicAuth, error) {
func NewBasicAuth(ctx context.Context, cfg config.AuthConfig) (BasicAuth, error) {
b := BasicAuth{mode: noAuth}

if cfg.LDAP == nil {
Expand All @@ -43,8 +44,24 @@ func NewBasicAuth(cfg config.AuthConfig) (BasicAuth, error) {
if err := b.configureLdap(ldap); err != nil {
return b, fmt.Errorf("failed to configure the request authenticator: %w", err)
}
ldap.SetMaxConnectionLifetime(cfg.LDAP.MaxConnectionLifetime)
b.mode = ldapMode

// trying a connection to LDAP to check the configuration
conn, err := ldap.connect()
if err != nil {
log.Warn().Err(err).Msg("test connection to LDAP failed")
}
if conn != nil {
if err := conn.Close(); err != nil {
log.Warn().Err(err).Msg("failed to close LDAP test connection")
}
}

if err := ldap.StartAuthenticationWorkers(ctx, cfg.LDAP.WorkersCount); err != nil {
return b, fmt.Errorf("failed to start LDAP workers: %w", err)
}

return b, nil
}

Expand Down
217 changes: 195 additions & 22 deletions internal/api/auth/ldap.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,215 @@
package auth

import (
"context"
"crypto/tls"
"errors"
"fmt"
"time"

"github.com/rs/zerolog/log"

"github.com/go-ldap/ldap/v3"
)

type connectionStatus bool

var ErrLDAPTimeout = errors.New("LDAP timeout")

const (
connectionClosed connectionStatus = false
connectionUp connectionStatus = true
)

type authRequest struct {
authResp chan bool
username string
password string
}

type result struct {
auth bool
conn connectionStatus
}

type LDAPAuth struct {
ldapClient *ldap.Conn
bindDN string
password string
baseDN string
tlsConfig *tls.Config
reqCh chan authRequest
ldapURL string
bindDN string
password string
baseDN string
maxConnectionLifetime time.Duration
}

func NewLDAPAuth(ldapURL string, bindDN string, password string, baseDN string, tlsConfig *tls.Config) *LDAPAuth {
conn, err := ldap.DialURL(ldapURL, ldap.DialWithTLSConfig(tlsConfig))
if err != nil {
log.Error().Err(err).Str("ldapURL", ldapURL).Msg("failed to connect to the LDAP server")
return nil
}
return &LDAPAuth{
ldapClient: conn,
bindDN: bindDN,
password: password,
baseDN: baseDN,
tlsConfig: tlsConfig,
ldapURL: ldapURL,
bindDN: bindDN,
password: password,
baseDN: baseDN,
reqCh: make(chan authRequest),
}
}

func SetLDAPDefaultTimeout(timeout time.Duration) {
ldap.DefaultTimeout = timeout //nolint:reassign // we want to customize the default timeout
}

// SetMaxConnectionLifetime sets the maximum lifetime of a connection.
//
// The maximum lifetime is the maximum amount of time a connection may be reused for.
// This is not a guarantee, as the connection may have been closed by the server before reaching that timer.
func (l *LDAPAuth) SetMaxConnectionLifetime(maxConnectionLifetime time.Duration) {
l.maxConnectionLifetime = maxConnectionLifetime
}

// StartAuthenticationWorkers starts a pool of workers that will handle the authentication requests.
func (l *LDAPAuth) StartAuthenticationWorkers(ctx context.Context, workersCount int) error {
if workersCount <= 0 {
return fmt.Errorf("'WorkersCount' must be greater than 0: %d", workersCount)
}
for i := 0; i < workersCount; i++ {
go l.spawnConnectionWorker(ctx)
}
return nil
}
func (l *LDAPAuth) AuthenticateUser(username string, password string) bool {
if err := l.ldapClient.Bind(l.bindDN, l.password); err != nil {
req := authRequest{
username: username,
password: password,
authResp: make(chan bool),
}
l.reqCh <- req
return <-req.authResp
}

func (l *LDAPAuth) spawnConnectionWorker(ctx context.Context) {
const maxAttempts = 3
var conn *ldap.Conn
var err error
tick := time.NewTicker(l.maxConnectionLifetime)

for {
select {
case req := <-l.reqCh:
auth := false
attempt := 1
for attempt <= maxAttempts {
attempt++
log.Debug().Msgf("worker LDAP authentication attempt number %d", attempt)
// (re)connect if needed
if conn == nil || conn.IsClosing() {
log.Debug().Msg("LDAP connection is closed, reconnecting")
conn, err = l.connect()
if err != nil {
log.Error().Err(err).Msg("worker LDAP reconnection failed")
req.authResp <- false
break
}
}

// bind with the user credentials
var connState connectionStatus
func() {
// this anonymous function ensures the context is released as soon as possible (because of the for loop)
ctxTimeout, cancel := context.WithTimeoutCause(ctx, ldap.DefaultTimeout, ErrLDAPTimeout)
defer cancel()
auth, connState = l.authenticateWithTimeout(ctxTimeout, conn, req.username, req.password)
}()

if connState == connectionClosed {
log.Debug().Msg("LDAP connection was closed by the server, closing on client side")
if err := conn.Close(); err != nil {
log.Error().Err(err).Msg("connection was closed by the server but failed to close on client side")
}
} else {
// LDAP connection is still up, we accept the authentication result
log.Debug().Msg("auth response valid")
break
}
}

log.Debug().Msgf("worker LDAP authentication attempt number %d, result: %t", attempt, auth)

req.authResp <- auth
tick.Reset(l.maxConnectionLifetime)

case <-tick.C:
// close connection if no request has been made
log.Debug().Msg("timer reached, closing connection")
closeLDAPConnection(conn)

case <-ctx.Done():
// gracefully close connection if context is done
log.Debug().Msg("context is closed, closing connection")
closeLDAPConnection(conn)
return
}
}
}

func closeLDAPConnection(conn *ldap.Conn) {
if conn != nil && !conn.IsClosing() {
if err := conn.Close(); err != nil {
log.Error().Err(err).Msg("unable to close the LDAP connection")
}
}
}

// authenticateWithTimeout performs the authentication against LDAP with a timeout.
func (l *LDAPAuth) authenticateWithTimeout(ctx context.Context, conn *ldap.Conn, username, password string) (bool, connectionStatus) {
// request the authentication
res := make(chan result, 1)
go func() {
a, c := l.authenticate(conn, username, password)
res <- result{auth: a, conn: c}
}()

var connState connectionStatus
var auth bool

// handle timeout and context closing
select {
case r := <-res:
auth = r.auth
connState = r.conn
case <-ctx.Done():
auth = false
if errors.Is(context.Cause(ctx), ErrLDAPTimeout) {
log.Error().Msg("LDAP authentication timeout")
closeLDAPConnection(conn)
} else {
connState = connectionClosed
}
}
return auth, connState
}

func (l *LDAPAuth) connect() (*ldap.Conn, error) {
conn, err := ldap.DialURL(l.ldapURL, ldap.DialWithTLSConfig(l.tlsConfig))
if err != nil {
return nil, fmt.Errorf("failed to dial to the LDAP server: %w", err)
}

return conn, nil
}

// authenticate performs the authentication against LDAP.
// The first returned boolean is true if the authentication is successful, false otherwise.
// The second returned boolean is false if the connection is closed, true otherwise.
func (l *LDAPAuth) authenticate(conn *ldap.Conn, username string, password string) (bool, connectionStatus) {
if err := conn.Bind(l.bindDN, l.password); err != nil {
log.Error().Err(err).Str("bindDN", l.bindDN).Msg("failed to bind to LDAP")
return false

// detect TCP connection closed or any network errors
if ldap.IsErrorWithCode(err, ldap.ErrorNetwork) {
return false, connectionClosed
}
return false, connectionUp
}
search, err := l.ldapClient.Search(ldap.NewSearchRequest(
search, err := conn.Search(ldap.NewSearchRequest(
l.baseDN,
ldap.ScopeWholeSubtree,
ldap.NeverDerefAliases,
Expand All @@ -50,16 +224,15 @@ func (l *LDAPAuth) AuthenticateUser(username string, password string) bool {
const userKey = "user"
if err != nil {
log.Error().Err(err).Str(userKey, username).Msg("failed to perform LDAP search to find user")
return false
return false, connectionUp
}
if len(search.Entries) != 1 {
log.Error().Str(userKey, username).Msg("no result or more than 1 result found for user")
return false
return false, connectionUp
}
if err := l.ldapClient.Bind(search.Entries[0].DN, password); err != nil {
if err := conn.Bind(search.Entries[0].DN, password); err != nil {
log.Error().Err(err).Str(userKey, username).Msg("failed to bind with user")
return false
return false, connectionUp
}
log.Debug().Str(userKey, username).Msg("succesfully authenticated user")
return true
return true, connectionUp
}
12 changes: 9 additions & 3 deletions internal/api/router/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"time"

"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog/log"
Expand All @@ -15,6 +16,8 @@ import (
"github.com/julienschmidt/httprouter"
)

const shutdownTimeout = 5 * time.Second

type DevicesRepository interface {
Set(devices map[string]*device.Device)
ListAFKEnabledDevicesJSON() ([]byte, error)
Expand All @@ -39,7 +42,7 @@ func (m *Manager) ListenAndServe(ctx context.Context, address string, port int)
log.Warn().Msg("Shutdown.")
}()

withAuth, err := auth.NewBasicAuth(config.Cfg.Authentication)
withAuth, err := auth.NewBasicAuth(ctx, config.Cfg.Authentication)
if err != nil {
return err
}
Expand All @@ -63,12 +66,15 @@ func (m *Manager) ListenAndServe(ctx context.Context, address string, port int)
// TODO: handle http failure! with a channel
go func() {
if err := httpServer.ListenAndServe(); err != nil {
log.Error().Err(err).Send()
log.Error().Err(err).Msg("stopped to listen and serve")
}
}()

<-ctx.Done()
if err := httpServer.Shutdown(context.Background()); err != nil {
ctxCancel, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()

if err := httpServer.Shutdown(ctxCancel); err != nil {
log.Error().Err(err).Send()
}

Expand Down
Loading