Skip to content

Commit

Permalink
Replace gRPC errors in business logic with internal ones (#558)
Browse files Browse the repository at this point in the history
  • Loading branch information
braginini authored Nov 11, 2022
1 parent 1db4027 commit 509d23c
Show file tree
Hide file tree
Showing 35 changed files with 768 additions and 847 deletions.
70 changes: 34 additions & 36 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"math/rand"
"net"
"net/netip"
Expand Down Expand Up @@ -52,7 +51,7 @@ type AccountManager interface {
SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, userID, keyID string) (*SetupKey, error)
GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error)
Expand Down Expand Up @@ -265,14 +264,14 @@ func (a *Account) FindPeerByPubKey(peerPubKey string) (*Peer, error) {
}
}

return nil, status.Errorf(codes.NotFound, "peer with the public key %s not found", peerPubKey)
return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
}

// FindUser looks for a given user in the Account or returns error if user wasn't found.
func (a *Account) FindUser(userID string) (*User, error) {
user := a.Users[userID]
if user == nil {
return nil, Errorf(UserNotFound, "user %s not found", userID)
return nil, status.Errorf(status.NotFound, "user %s not found", userID)
}

return user, nil
Expand All @@ -282,7 +281,7 @@ func (a *Account) FindUser(userID string) (*User, error) {
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
key := a.SetupKeys[setupKey]
if key == nil {
return nil, Errorf(SetupKeyNotFound, "setup key not found")
return nil, status.Errorf(status.NotFound, "setup key not found")
}

return key, nil
Expand Down Expand Up @@ -458,14 +457,14 @@ func (am *DefaultAccountManager) newAccount(userID, domain string) (*Account, er
if err == nil {
log.Warnf("an account with ID already exists, retrying...")
continue
} else if statusErr.Code() == codes.NotFound {
} else if statusErr.Type() == status.NotFound {
return newAccountWithId(accountId, userID, domain), nil
} else {
return nil, err
}
}

return nil, status.Errorf(codes.Internal, "error while creating new account")
return nil, status.Errorf(status.Internal, "error while creating new account")
}

func (am *DefaultAccountManager) warmupIDPCache() error {
Expand All @@ -492,7 +491,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
} else if userID != "" {
account, err := am.GetOrCreateAccountByUser(userID, domain)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID)
return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
}
err = am.addAccountIDToIDPAppMeta(userID, account)
if err != nil {
Expand All @@ -501,7 +500,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID,
return account, nil
}

return nil, status.Errorf(codes.NotFound, "no valid user or account Id provided")
return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
}

func isNil(i idp.Manager) bool {
Expand Down Expand Up @@ -531,11 +530,7 @@ func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account
}

if err != nil {
return status.Errorf(
codes.Internal,
"updating user's app metadata failed with: %v",
err,
)
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
_, err = am.refreshCache(account.Id)
Expand Down Expand Up @@ -662,11 +657,8 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, a
}

// updateAccountDomainAttributes updates the account domain attributes and then, saves the account
func (am *DefaultAccountManager) updateAccountDomainAttributes(
account *Account,
claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
primaryDomain bool) error {
account.IsDomainPrimaryAccount = primaryDomain

lowerDomain := strings.ToLower(claims.Domain)
Expand All @@ -681,7 +673,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributes(

err := am.Store.SaveAccount(account)
if err != nil {
return status.Errorf(codes.Internal, "failed saving updated account")
return err
}
return nil
}
Expand Down Expand Up @@ -723,10 +715,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(

// handleNewUserAccount validates if there is an existing primary account for the domain, if so it adds the new user to that account,
// otherwise it will create a new account and make it primary account for the domain.
func (am *DefaultAccountManager) handleNewUserAccount(
domainAcc *Account,
claims jwtclaims.AuthorizationClaims,
) (*Account, error) {
func (am *DefaultAccountManager) handleNewUserAccount(domainAcc *Account, claims jwtclaims.AuthorizationClaims) (*Account, error) {
var (
account *Account
err error
Expand All @@ -738,7 +727,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
account.Users[claims.UserId] = NewRegularUser(claims.UserId)
err = am.Store.SaveAccount(account)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed saving updated account")
return nil, err
}
} else {
account, err = am.newAccount(claims.UserId, lowerDomain)
Expand Down Expand Up @@ -773,7 +762,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
}

if user == nil {
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
return status.Errorf(status.NotFound, "user %s not found in the IdP", userID)
}

if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
Expand All @@ -794,7 +783,7 @@ func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) e
}

// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {

if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
Expand All @@ -806,15 +795,21 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat

account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, err
return nil, nil, err
}

user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}

err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, err
return nil, nil, err
}

return account, nil
return account, user, nil
}

// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
Expand Down Expand Up @@ -857,9 +852,12 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla

// We checked if the domain has a primary account already
domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain)
accStatus, _ := status.FromError(err)
if accStatus.Code() != codes.OK && accStatus.Code() != codes.NotFound {
return nil, err
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
return nil, err
}
}

account, err := am.Store.GetAccountByUser(claims.UserId)
Expand All @@ -869,7 +867,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
return nil, err
}
return account, nil
} else if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
return am.handleNewUserAccount(domainAccount, claims)
} else {
// other error
Expand All @@ -891,7 +889,7 @@ func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error)
var res bool
_, err := am.Store.GetAccount(accountID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
res = false
return &res, nil
} else {
Expand Down
2 changes: 1 addition & 1 deletion management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}

account, err := manager.GetAccountFromToken(testCase.inputClaims)
account, _, err := manager.GetAccountFromToken(testCase.inputClaims)
require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
Expand Down
61 changes: 0 additions & 61 deletions management/server/error.go

This file was deleted.

19 changes: 7 additions & 12 deletions management/server/file_store.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
package server

import (
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
"os"
"path/filepath"
"strings"
"sync"
"time"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/netbirdio/netbird/util"
)

Expand Down Expand Up @@ -192,10 +190,7 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) {

accountID, accountIDFound := s.PrivateDomain2AccountID[strings.ToLower(domain)]
if !accountIDFound {
return nil, status.Errorf(
codes.NotFound,
"account not found: provided domain is not registered or is not private",
)
return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}

account, err := s.getAccount(accountID)
Expand All @@ -213,7 +208,7 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) {

accountID, accountIDFound := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists")
return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
}

account, err := s.getAccount(accountID)
Expand All @@ -239,7 +234,7 @@ func (s *FileStore) GetAllAccounts() (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, accountFound := s.Accounts[accountID]
if !accountFound {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, status.Errorf(status.NotFound, "account not found")
}

return account, nil
Expand All @@ -265,7 +260,7 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) {

accountID, accountIDFound := s.UserID2AccountID[userID]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "account not found")
return nil, status.Errorf(status.NotFound, "account not found")
}

account, err := s.getAccount(accountID)
Expand All @@ -283,7 +278,7 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {

accountID, accountIDFound := s.PeerKeyID2AccountID[peerKey]
if !accountIDFound {
return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey)
return nil, status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey)
}

account, err := s.getAccount(accountID)
Expand Down Expand Up @@ -322,7 +317,7 @@ func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerSta

peer := account.Peers[peerKey]
if peer == nil {
return status.Errorf(codes.NotFound, "peer %s not found", peerKey)
return status.Errorf(status.NotFound, "peer %s not found", peerKey)
}

peer.Status = &peerStatus
Expand Down
Loading

0 comments on commit 509d23c

Please sign in to comment.