From 35c892aea365fa9e3e710ecc45dc7c0a353b8c71 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 20 Sep 2024 12:36:58 +0300 Subject: [PATCH] [management] Restrict accessible peers to user-owned peers for non-admins (#2618) * Restrict accessible peers to user-owned peers for non-admin users Signed-off-by: bcmmbaga * add tests Signed-off-by: bcmmbaga * add service user test Signed-off-by: bcmmbaga * reuse account from token Signed-off-by: bcmmbaga * return error when peer not found Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/http/peers_handler.go | 20 +- management/server/http/peers_handler_test.go | 194 +++++++++++++++++-- 2 files changed, 198 insertions(+), 16 deletions(-) diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 1487bbc3944..5a2190d83fa 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,8 +7,6 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -16,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -215,7 +214,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -228,6 +227,21 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request return } + // If the user is regular user and does not own the peer + // with the given peerID return an empty list + if !user.HasAdminPower() && !user.IsServiceUser { + peer, ok := account.Peers[peerID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w) + return + } + + if peer.UserID != user.Id { + util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{}) + return + } + } + dnsDomain := h.accountManager.GetDNSDomain() validPeers, err := h.accountManager.GetValidatedPeers(account) diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 153c8f03a61..dae264fff11 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -12,20 +13,30 @@ import ( "time" "github.com/gorilla/mux" - + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server/jwtclaims" - "github.com/magiconair/properties/assert" + "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) -const testPeerID = "test_peer" -const noUpdateChannelTestPeerID = "no-update-channel" +type ctxKey string + +const ( + testPeerID = "test_peer" + noUpdateChannelTestPeerID = "no-update-channel" + + adminUser = "admin_user" + regularUser = "regular_user" + serviceUser = "service_user" + userIDKey ctxKey = "user_id" +) func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { return &PeersHandler{ @@ -60,21 +71,57 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { return "netbird.selfhosted" }, GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &server.Policy{ + ID: "policy", + AccountID: claims.AccountId, + Name: "policy", + Enabled: true, + Rules: []*server.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, + }, + } + + srvUser := server.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", - Peers: map[string]*nbpeer.Peer{ - peers[0].ID: peers[0], - peers[1].ID: peers[1], - }, + Peers: peersMap, Users: map[string]*server.User{ - "test_user": user, + adminUser: server.NewAdminUser(adminUser), + regularUser: server.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*nbgroup.Group{ + "group1": { + ID: "group1", + AccountID: claims.AccountId, + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, }, Settings: &server.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, + Policies: []*server.Policy{policy}, Network: &server.Network{ Identifier: "ciclqisab2ss43jdn8q0", Net: net.IPNet{ @@ -83,7 +130,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, Serial: 51, }, - }, user, nil + } + + return account, account.Users[claims.UserId], nil }, HasConnectedChannelFunc: func(peerID string) bool { statuses := make(map[string]struct{}) @@ -99,8 +148,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { + userID := r.Context().Value(userIDKey).(string) return jwtclaims.AuthorizationClaims{ - UserId: "test_user", + UserId: userID, Domain: "hotmail.com", AccountId: "test_id", } @@ -197,6 +247,8 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + ctx := context.WithValue(context.Background(), userIDKey, "admin_user") + req = req.WithContext(ctx) router := mux.NewRouter() router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") @@ -251,3 +303,119 @@ func TestGetPeers(t *testing.T) { }) } } + +func TestGetAccessiblePeers(t *testing.T) { + peer1 := &nbpeer.Peer{ + ID: "peer1", + Key: "key1", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true}, + Name: "peer1", + LoginExpirationEnabled: false, + UserID: regularUser, + } + + peer2 := &nbpeer.Peer{ + ID: "peer2", + Key: "key2", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{Connected: true}, + Name: "peer2", + LoginExpirationEnabled: false, + UserID: adminUser, + } + + peer3 := &nbpeer.Peer{ + ID: "peer3", + Key: "key3", + IP: net.ParseIP("100.64.0.3"), + Status: &nbpeer.PeerStatus{Connected: true}, + Name: "peer3", + LoginExpirationEnabled: false, + UserID: regularUser, + } + + p := initTestMetaData(peer1, peer2, peer3) + + tt := []struct { + name string + peerID string + callerUserID string + expectedStatus int + expectedPeers []string + }{ + { + name: "non admin user can access owned peer", + peerID: "peer1", + callerUserID: regularUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer2", "peer3"}, + }, + { + name: "non admin user can't access unowned peer", + peerID: "peer2", + callerUserID: regularUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{}, + }, + { + name: "admin user can access owned peer", + peerID: "peer2", + callerUserID: adminUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer1", "peer3"}, + }, + { + name: "admin user can access unowned peer", + peerID: "peer3", + callerUserID: adminUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer1", "peer2"}, + }, + { + name: "service user can access unowned peer", + peerID: "peer3", + callerUserID: serviceUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer1", "peer2"}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) + ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID) + req = req.WithContext(ctx) + + router := mux.NewRouter() + router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + if res.StatusCode != tc.expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + defer res.Body.Close() + + var accessiblePeers []api.AccessiblePeer + err = json.Unmarshal(body, &accessiblePeers) + if err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + peerIDs := make([]string, len(accessiblePeers)) + for i, peer := range accessiblePeers { + peerIDs[i] = peer.Id + } + + assert.ElementsMatch(t, peerIDs, tc.expectedPeers) + }) + } +}