Skip to content

Commit

Permalink
Handle auth error (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aiee authored Oct 20, 2022
1 parent a3f092c commit 7d1ce6a
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 17 deletions.
28 changes: 26 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func TestConnection(t *testing.T) {
t.Fatalf("fail to authenticate, username: %s, password: %s, %s", username, password, authErr.Error())
}

if authResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
t.Fatalf("failed to authenticate, error code: %d, error msg: %s",
authResp.GetErrorCode(), authResp.GetErrorMsg())
}

sessionID := authResp.GetSessionID()

defer logoutAndClose(conn, sessionID)
Expand Down Expand Up @@ -127,6 +132,11 @@ func TestConnectionIPv6(t *testing.T) {
t.Fatalf("fail to authenticate, username: %s, password: %s, %s", username, password, authErr.Error())
}

if authResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
t.Fatalf("failed to authenticate, error code: %d, error msg: %s",
authResp.GetErrorCode(), authResp.GetErrorMsg())
}

sessionID := authResp.GetSessionID()

defer logoutAndClose(conn, sessionID)
Expand Down Expand Up @@ -250,8 +260,12 @@ func TestAuthentication(t *testing.T) {
}
defer conn.close()

_, authErr := conn.authenticate(username, password)
assert.EqualError(t, authErr, "fail to authenticate, error: User not exist")
resp, authErr := conn.authenticate(username, password)
if authErr != nil {
t.Fatalf("fail to authenticate, username: %s, password: %s, %s", username, password, authErr.Error())
}

assert.Equal(t, string(resp.GetErrorMsg()), "User not exist")
}

func TestInvalidHostTimeout(t *testing.T) {
Expand Down Expand Up @@ -1398,6 +1412,11 @@ func prepareSpace(spaceName string) error {
return fmt.Errorf("fail to authenticate, username: %s, password: %s, %s", username, password, authErr.Error())
}

if authResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
return fmt.Errorf("failed to authenticate, error code: %d, error msg: %s",
authResp.GetErrorCode(), authResp.GetErrorMsg())
}

sessionID := authResp.GetSessionID()

defer logoutAndClose(conn, sessionID)
Expand Down Expand Up @@ -1430,6 +1449,11 @@ func dropSpace(spaceName string) error {
return fmt.Errorf("fail to authenticate, username: %s, password: %s, %s", username, password, authErr.Error())
}

if authResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
return fmt.Errorf("failed to authenticate, error code: %d, error msg: %s",
authResp.GetErrorCode(), authResp.GetErrorMsg())
}

sessionID := authResp.GetSessionID()

defer logoutAndClose(conn, sessionID)
Expand Down
6 changes: 2 additions & 4 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ func (cn *connection) authenticate(username, password string) (*graph.AuthRespon
}
return nil, err
}
if resp.ErrorCode != nebula.ErrorCode_SUCCEEDED {
return nil, fmt.Errorf("fail to authenticate, error: %s", resp.ErrorMsg)
}
return resp, err

return resp, nil
}

func (cn *connection) execute(sessionID int64, stmt string) (*graph.ExecutionResponse, error) {
Expand Down
8 changes: 7 additions & 1 deletion connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (pool *ConnectionPool) GetSession(username, password string) (*Session, err
}
// Authenticate
resp, err := conn.authenticate(username, password)
if err != nil || resp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
if err != nil {
// if authentication failed, put connection back
pool.rwLock.Lock()
defer pool.rwLock.Unlock()
Expand All @@ -121,6 +121,12 @@ func (pool *ConnectionPool) GetSession(username, password string) (*Session, err
return nil, err
}

// Check auth response
if resp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
return nil, fmt.Errorf("failed to authenticate, error code: %d, error msg: %s",
resp.GetErrorCode(), resp.GetErrorMsg())
}

sessID := resp.GetSessionID()
timezoneOffset := resp.GetTimeZoneOffsetSeconds()
timezoneName := resp.GetTimeZoneName()
Expand Down
24 changes: 14 additions & 10 deletions session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,14 @@ func (pool *SessionPool) newSession() (*Session, error) {

// If the authentication failed, close the session pool because the pool must have a valid user to work
if authResp.GetErrorCode() != 0 {
pool.Close()
return nil, fmt.Errorf("failed to authenticate the user, error code: %d, error message: %s, the pool has been closed",
authResp.ErrorCode, authResp.ErrorMsg)
if authResp.GetErrorCode() == nebula.ErrorCode_E_BAD_USERNAME_PASSWORD ||
authResp.GetErrorCode() == nebula.ErrorCode_E_USER_NOT_FOUND {
pool.Close()
return nil, fmt.Errorf(
"failed to authenticate the user, error code: %d, error message: %s, the pool has been closed",
authResp.ErrorCode, authResp.ErrorMsg)
}
return nil, fmt.Errorf("failed to create a new session: %s", authResp.GetErrorMsg())

This comment has been minimized.

Copy link
@lee-qiu

lee-qiu Jan 5, 2023

why not close the connection when auth failed?

}

sessID := authResp.GetSessionID()
Expand All @@ -327,19 +332,18 @@ func (pool *SessionPool) newSession() (*Session, error) {
log: pool.log,
timezoneInfo: timezoneInfo{timezoneOffset, timezoneName},
}
err = newSession.Ping()
if err != nil {
return nil, err
}

// Switch to the default space
stmt := fmt.Sprintf("USE %s", pool.conf.spaceName)
createSpaceResp, err := newSession.connection.execute(newSession.sessionID, stmt)
useSpaceResp, err := newSession.connection.execute(newSession.sessionID, stmt)
if err != nil {
return nil, err

This comment has been minimized.

Copy link
@lee-qiu

lee-qiu Jan 5, 2023

ditto

}
if createSpaceResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {

if useSpaceResp.GetErrorCode() != nebula.ErrorCode_SUCCEEDED {
newSession.connection.close()
return nil, fmt.Errorf("failed to use space %s: %s",
pool.conf.spaceName, createSpaceResp.GetErrorMsg())
pool.conf.spaceName, useSpaceResp.GetErrorMsg())

This comment has been minimized.

Copy link
@lee-qiu

lee-qiu Jan 5, 2023

ditto

}
return &newSession, nil
}
Expand Down

0 comments on commit 7d1ce6a

Please sign in to comment.