Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
  • Loading branch information
bcmmbaga committed Sep 18, 2024
1 parent 416ef15 commit 0a941f4
Showing 1 changed file with 169 additions and 13 deletions.
182 changes: 169 additions & 13 deletions management/server/http/peers_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -12,20 +13,29 @@ import (
"time"

"github.com/gorilla/mux"

nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"golang.org/x/exp/maps"

"github.com/netbirdio/netbird/management/server/jwtclaims"

"github.com/magiconair/properties/assert"
"github.com/stretchr/testify/assert"

"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)

const testPeerID = "test_peer"
const noUpdateChannelTestPeerID = "no-update-channel"
type ctxKey string

const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"

adminUser = "admin_user"
regularUser = "regular_user"
userIDKey ctxKey = "user_id"
)

func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
Expand Down Expand Up @@ -60,21 +70,53 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
peersMap := make(map[string]*nbpeer.Peer)
for _, peer := range peers {
peersMap[peer.ID] = peer.Copy()
}

policy := &server.Policy{
ID: "policy",
AccountID: claims.AccountId,
Name: "policy",
Enabled: true,
Rules: []*server.PolicyRule{
{
ID: "rule",
Name: "rule",
Enabled: true,
Action: "accept",
Destinations: []string{"group1"},
Sources: []string{"group1"},
Bidirectional: true,
Protocol: "all",
Ports: []string{"80"},
},
},
}

account := &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*nbpeer.Peer{
peers[0].ID: peers[0],
peers[1].ID: peers[1],
},
Peers: peersMap,
Users: map[string]*server.User{
"test_user": user,
adminUser: server.NewAdminUser(adminUser),
regularUser: server.NewRegularUser(regularUser),
},
Groups: map[string]*nbgroup.Group{
"group1": {
ID: "group1",
AccountID: claims.AccountId,
Name: "group1",
Issued: "api",
Peers: maps.Keys(peersMap),
},
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
Expand All @@ -83,7 +125,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
Serial: 51,
},
}, user, nil
}

return account, account.Users[claims.UserId], nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
Expand All @@ -99,8 +143,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
Expand Down Expand Up @@ -197,6 +242,8 @@ func TestGetPeers(t *testing.T) {

recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
req = req.WithContext(ctx)

router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
Expand Down Expand Up @@ -251,3 +298,112 @@ func TestGetPeers(t *testing.T) {
})
}
}

func TestGetAccessiblePeers(t *testing.T) {
peer1 := &nbpeer.Peer{
ID: "peer1",
Key: "key1",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer1",
LoginExpirationEnabled: false,
UserID: regularUser,
}

peer2 := &nbpeer.Peer{
ID: "peer2",
Key: "key2",
IP: net.ParseIP("100.64.0.2"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer2",
LoginExpirationEnabled: false,
UserID: adminUser,
}

peer3 := &nbpeer.Peer{
ID: "peer3",
Key: "key3",
IP: net.ParseIP("100.64.0.3"),
Status: &nbpeer.PeerStatus{Connected: true},
Name: "peer3",
LoginExpirationEnabled: false,
UserID: regularUser,
}

p := initTestMetaData(peer1, peer2, peer3)

tt := []struct {
name string
peerID string
callerUserID string
expectedStatus int
expectedPeers []string
}{
{
name: "non admin user can access owned peer",
peerID: "peer1",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer2", "peer3"},
},
{
name: "non admin user can't access unowned peer",
peerID: "peer2",
callerUserID: regularUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user can access owned peer",
peerID: "peer2",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "admin user can access unowned peer",
peerID: "peer3",
callerUserID: adminUser,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
req = req.WithContext(ctx)

router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.ServeHTTP(recorder, req)

res := recorder.Result()
if res.StatusCode != tc.expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
}

body, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
defer res.Body.Close()

var accessiblePeers []api.AccessiblePeer
err = json.Unmarshal(body, &accessiblePeers)
if err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}

peerIDs := make([]string, len(accessiblePeers))
for i, peer := range accessiblePeers {
peerIDs[i] = peer.Id
}

assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
})
}
}

0 comments on commit 0a941f4

Please sign in to comment.