Skip to content

Commit

Permalink
Support multiple LDAP servers in a auth source (#6898)
Browse files Browse the repository at this point in the history
Signed-off-by: abhishek818 <abhishekguptaatweb17@gmail.com>
  • Loading branch information
abhishek818 committed Jul 17, 2024
1 parent de1a550 commit 75ee729
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 29 deletions.
2 changes: 1 addition & 1 deletion cmd/admin_auth_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func parseLdapConfig(c *cli.Context, config *ldap.Source) error {
config.Name = c.String("name")
}
if c.IsSet("host") {
config.Host = c.String("host")
config.HostList = c.String("hostlist")
}
if c.IsSet("port") {
config.Port = c.Int("port")
Expand Down
16 changes: 8 additions & 8 deletions cmd/admin_auth_ldap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestAddLdapBindDn(t *testing.T) {
IsSyncEnabled: true,
Cfg: &ldap.Source{
Name: "ldap (via Bind DN) source full",
Host: "ldap-bind-server full",
HostList: "ldap-bind-server full",
Port: 9876,
SecurityProtocol: ldap.SecurityProtocol(1),
SkipVerify: true,
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestAddLdapBindDn(t *testing.T) {
IsActive: true,
Cfg: &ldap.Source{
Name: "ldap (via Bind DN) source min",
Host: "ldap-bind-server min",
HostList: "ldap-bind-server min",
Port: 1234,
SecurityProtocol: ldap.SecurityProtocol(0),
UserBase: "ou=Users,dc=min-domain-bind,dc=org",
Expand Down Expand Up @@ -280,7 +280,7 @@ func TestAddLdapSimpleAuth(t *testing.T) {
IsActive: false,
Cfg: &ldap.Source{
Name: "ldap (simple auth) source full",
Host: "ldap-simple-server full",
HostList: "ldap-simple-server full",
Port: 987,
SecurityProtocol: ldap.SecurityProtocol(2),
SkipVerify: true,
Expand Down Expand Up @@ -317,7 +317,7 @@ func TestAddLdapSimpleAuth(t *testing.T) {
IsActive: true,
Cfg: &ldap.Source{
Name: "ldap (simple auth) source min",
Host: "ldap-simple-server min",
HostList: "ldap-simple-server min",
Port: 123,
SecurityProtocol: ldap.SecurityProtocol(0),
UserDN: "cn=%s,ou=Users,dc=min-domain-simple,dc=org",
Expand Down Expand Up @@ -526,7 +526,7 @@ func TestUpdateLdapBindDn(t *testing.T) {
IsSyncEnabled: true,
Cfg: &ldap.Source{
Name: "ldap (via Bind DN) source full",
Host: "ldap-bind-server full",
HostList: "ldap-bind-server full",
Port: 9876,
SecurityProtocol: ldap.SecurityProtocol(1),
SkipVerify: true,
Expand Down Expand Up @@ -630,7 +630,7 @@ func TestUpdateLdapBindDn(t *testing.T) {
authSource: &auth.Source{
Type: auth.LDAP,
Cfg: &ldap.Source{
Host: "ldap-server",
HostList: "ldap-server",
},
},
},
Expand Down Expand Up @@ -978,7 +978,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
IsActive: false,
Cfg: &ldap.Source{
Name: "ldap (simple auth) source full",
Host: "ldap-simple-server full",
HostList: "ldap-simple-server full",
Port: 987,
SecurityProtocol: ldap.SecurityProtocol(2),
SkipVerify: true,
Expand Down Expand Up @@ -1078,7 +1078,7 @@ func TestUpdateLdapSimpleAuth(t *testing.T) {
authSource: &auth.Source{
Type: auth.DLDAP,
Cfg: &ldap.Source{
Host: "ldap-server",
HostList: "ldap-server",
},
},
},
Expand Down
2 changes: 1 addition & 1 deletion routers/web/admin/auths.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func parseLDAPConfig(form forms.AuthenticationForm) *ldap.Source {
}
return &ldap.Source{
Name: form.Name,
Host: form.Host,
HostList: form.Host,
Port: form.Port,
SecurityProtocol: ldap.SecurityProtocol(form.SecurityProtocol),
SkipVerify: form.SkipVerify,
Expand Down
2 changes: 1 addition & 1 deletion services/auth/source/ldap/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
// Source Basic LDAP authentication service
type Source struct {
Name string // canonical name (ie. corporate.ad)
Host string // LDAP host
HostList string // list containing LDAP host(s)
Port int // port number
SecurityProtocol SecurityProtocol
SkipVerify bool
Expand Down
56 changes: 38 additions & 18 deletions services/auth/source/ldap/source_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net"
"strconv"
"strings"
"time"

"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log"
Expand Down Expand Up @@ -111,28 +112,47 @@ func (source *Source) findUserDN(l *ldap.Conn, name string) (string, bool) {
func dial(source *Source) (*ldap.Conn, error) {
log.Trace("Dialing LDAP with security protocol (%v) without verifying: %v", source.SecurityProtocol, source.SkipVerify)

tlsConfig := &tls.Config{
ServerName: source.Host,
InsecureSkipVerify: source.SkipVerify,
}
ldap.DefaultTimeout = time.Second * 15
// HostList is a list of hosts separated by commas
hostList := strings.Split(source.HostList, ",")

if source.SecurityProtocol == SecurityProtocolLDAPS {
return ldap.DialTLS("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port)), tlsConfig)
}
for _, host := range hostList {
tlsConfig := &tls.Config{
ServerName: host,
InsecureSkipVerify: source.SkipVerify,
}

conn, err := ldap.Dial("tcp", net.JoinHostPort(source.Host, strconv.Itoa(source.Port)))
if err != nil {
return nil, fmt.Errorf("error during Dial: %w", err)
}
if source.SecurityProtocol == SecurityProtocolLDAPS {
conn, err := ldap.DialTLS("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)), tlsConfig)

if err != nil {
// Connection failed, try again with the next host.
log.Trace("error during Dial for host %s: %w", host, err)
continue
}
conn.SetTimeout(time.Second * 10)

if source.SecurityProtocol == SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsConfig); err != nil {
conn.Close()
return nil, fmt.Errorf("error during StartTLS: %w", err)
return conn, err
}

conn, err := ldap.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(source.Port)))
if err != nil {
log.Trace("error during Dial for host %s: %w", host, err)
continue
}
conn.SetTimeout(time.Second * 10)

if source.SecurityProtocol == SecurityProtocolStartTLS {
if err = conn.StartTLS(tlsConfig); err != nil {
conn.Close()
log.Trace("error during StartTLS for host %s: %w", host, err)
continue
}
}
}

return conn, nil
// All servers were unreachable
return nil, fmt.Errorf("dial failed for all provided servers: %s", hostList)
}

func bindUser(l *ldap.Conn, userDN, passwd string) error {
Expand Down Expand Up @@ -257,7 +277,7 @@ func (source *Source) SearchEntry(name, passwd string, directBind bool) *SearchR
}
l, err := dial(source)
if err != nil {
log.Error("LDAP Connect error, %s:%v", source.Host, err)
log.Error("LDAP Connect error, %s:%v", source.HostList, err)
source.Enabled = false
return nil
}
Expand Down Expand Up @@ -421,7 +441,7 @@ func (source *Source) UsePagedSearch() bool {
func (source *Source) SearchEntries() ([]*SearchResult, error) {
l, err := dial(source)
if err != nil {
log.Error("LDAP Connect error, %s:%v", source.Host, err)
log.Error("LDAP Connect error, %s:%v", source.HostList, err)
source.Enabled = false
return nil, err
}
Expand Down

0 comments on commit 75ee729

Please sign in to comment.