Skip to content

Commit

Permalink
Merge pull request #28 from wiretrustee/test-signal-grpc
Browse files Browse the repository at this point in the history
test: add basic signal IT tests
  • Loading branch information
braginini authored Jun 19, 2021
2 parents 6465e25 + db673ed commit 3c45da5
Show file tree
Hide file tree
Showing 8 changed files with 296 additions and 17 deletions.
4 changes: 2 additions & 2 deletions cmd/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sig "github.com/wiretrustee/wiretrustee/signal"
sProto "github.com/wiretrustee/wiretrustee/signal/proto"
sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
"google.golang.org/grpc"
"net"
)
Expand All @@ -30,7 +30,7 @@ var (
}
var opts []grpc.ServerOption
grpcServer := grpc.NewServer(opts...)
sProto.RegisterSignalExchangeServer(grpcServer, sig.NewServer())
sigProto.RegisterSignalExchangeServer(grpcServer, sig.NewServer())
log.Printf("started server: localhost:%v", port)
if err := grpcServer.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
Expand Down
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ go 1.16

require (
github.com/cenkalti/backoff/v4 v4.1.0
github.com/golang/protobuf v1.4.3
github.com/golang/protobuf v1.5.2
github.com/google/nftables v0.0.0-20201230142148-715e31cb3c31
github.com/onsi/ginkgo v1.16.4
github.com/onsi/gomega v1.13.0
github.com/pion/ice/v2 v2.1.7
github.com/sirupsen/logrus v1.7.0
github.com/spf13/cobra v1.1.3
Expand Down
55 changes: 52 additions & 3 deletions go.sum

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions signal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type Client struct {
ctx context.Context
stream proto.SignalExchange_ConnectStreamClient
//waiting group to notify once stream is connected
connWg sync.WaitGroup //todo use a channel instead??
connWg *sync.WaitGroup //todo use a channel instead??
}

// Close Closes underlying connections to the Signal Exchange
Expand All @@ -55,11 +55,13 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key) (*Client, erro
return nil, err
}

var wg sync.WaitGroup
return &Client{
realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx,
signalConn: conn,
key: key,
connWg: &wg,
}, nil
}

Expand Down Expand Up @@ -107,15 +109,22 @@ func (c *Client) connect(key string, msgHandler func(msg *proto.Message) error)
// add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key})
ctx := metadata.NewOutgoingContext(c.ctx, md)
ctx, cancel := context.WithCancel(ctx)
defer cancel()

stream, err := c.realClient.ConnectStream(ctx)

c.stream = stream
if err != nil {
return err
}
// blocks
header, err := c.stream.Header()
if err != nil {
return err
}
registered := header.Get(proto.HeaderRegistered)
if len(registered) == 0 {
return fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
}
//connection established we are good to use the stream
c.connWg.Done()

Expand Down
1 change: 1 addition & 0 deletions signal/proto/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ package proto

// protocol constants, field names that can be used by both client and server
const HeaderId = "x-wiretrustee-peer-id"
const HeaderRegistered = "x-wiretrustee-peer-registered"
22 changes: 14 additions & 8 deletions signal/signal.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ import (
"io"
)

// SignalExchangeServer an instance of a Signal server
type SignalExchangeServer struct {
// Server an instance of a Signal server
type Server struct {
registry *peer.Registry
}

// NewServer creates a new Signal server
func NewServer() *SignalExchangeServer {
return &SignalExchangeServer{
func NewServer() *Server {
return &Server{
registry: peer.NewRegistry(),
}
}

// Send forwards a message to the signal peer
func (s *SignalExchangeServer) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {

if !s.registry.IsPeerRegistered(msg.Key) {
return nil, fmt.Errorf("unknown peer %s", msg.Key)
Expand All @@ -46,14 +46,20 @@ func (s *SignalExchangeServer) Send(ctx context.Context, msg *proto.EncryptedMes
}

// ConnectStream connects to the exchange stream
func (s *SignalExchangeServer) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
p, err := s.connectPeer(stream)
if err != nil {
return err
}

log.Infof("peer [%s] has successfully connected", p.Id)
//needed to confirm that the peer has been registered so that the client can proceed
header := metadata.Pairs(proto.HeaderRegistered, "1")
err = stream.SendHeader(header)
if err != nil {
return err
}

log.Infof("peer [%s] has successfully connected", p.Id)
for {
msg, err := stream.Recv()
if err == io.EOF {
Expand Down Expand Up @@ -83,7 +89,7 @@ func (s *SignalExchangeServer) ConnectStream(stream proto.SignalExchange_Connect
// Handles initial Peer connection.
// Each connection must provide an ID header.
// At this moment the connecting Peer will be registered in the peer.Registry
func (s SignalExchangeServer) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
if id, found := meta[proto.HeaderId]; found {
p := peer.NewPeer(id[0], stream)
Expand Down
13 changes: 13 additions & 0 deletions signal/signal_suite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package signal_test

import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

"testing"
)

func TestSignal(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Signal Suite")
}
199 changes: 199 additions & 0 deletions signal/signal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package signal_test

import (
"context"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/signal"
sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
"net"
"sync"
"time"
)

var _ = Describe("Client", func() {

var (
addr string
listener net.Listener
server *grpc.Server
)

BeforeEach(func() {
server, listener = startSignal()
addr = listener.Addr().String()

})

AfterEach(func() {
server.Stop()
listener.Close()
})

Describe("Exchanging messages", func() {
Context("between connected peers", func() {
It("should be successful", func() {

var msgReceived sync.WaitGroup
msgReceived.Add(2)

var receivedOnA string
var receivedOnB string

// connect PeerA to Signal
keyA, _ := wgtypes.GenerateKey()
clientA := createSignalClient(addr, keyA)
clientA.Receive(func(msg *sigProto.Message) error {
receivedOnA = msg.GetBody().GetPayload()
msgReceived.Done()
return nil
})
clientA.WaitConnected()

// connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey()
clientB := createSignalClient(addr, keyB)
clientB.Receive(func(msg *sigProto.Message) error {
receivedOnB = msg.GetBody().GetPayload()
err := clientB.Send(&sigProto.Message{
Key: keyB.PublicKey().String(),
RemoteKey: keyA.PublicKey().String(),
Body: &sigProto.Body{Payload: "pong"},
})
if err != nil {
Fail("failed sending a message to PeerA")
}
msgReceived.Done()
return nil
})
clientB.WaitConnected()

// PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{
Key: keyA.PublicKey().String(),
RemoteKey: keyB.PublicKey().String(),
Body: &sigProto.Body{Payload: "ping"},
})
if err != nil {
Fail("failed sending a message to PeerB")
}

if waitTimeout(&msgReceived, 3*time.Second) {
Fail("test timed out on waiting for peers to exchange messages")
}

Expect(receivedOnA).To(BeEquivalentTo("pong"))
Expect(receivedOnB).To(BeEquivalentTo("ping"))

})
})
})

Describe("Connecting to the Signal stream channel", func() {
Context("with a signal client", func() {
It("should be successful", func() {

key, _ := wgtypes.GenerateKey()
client := createSignalClient(addr, key)
client.Receive(func(msg *sigProto.Message) error {
return nil
})
client.WaitConnected()

Expect(client).NotTo(BeNil())
})
})

Context("with a raw client and no ID header", func() {
It("should fail", func() {

client := createRawSignalClient(addr)
stream, err := client.ConnectStream(context.Background())
if err != nil {
Fail("error connecting to stream")
}

_, err = stream.Recv()

Expect(stream).NotTo(BeNil())
Expect(err).NotTo(BeNil())
})
})

Context("with a raw client and with an ID header", func() {
It("should be successful", func() {

md := metadata.New(map[string]string{sigProto.HeaderId: "peer"})
ctx := metadata.NewOutgoingContext(context.Background(), md)

client := createRawSignalClient(addr)
stream, err := client.ConnectStream(ctx)

Expect(stream).NotTo(BeNil())
Expect(err).To(BeNil())
})
})

})

})

func createSignalClient(addr string, key wgtypes.Key) *signal.Client {
client, err := signal.NewClient(context.Background(), addr, key)
if err != nil {
Fail("failed creating signal client")
}
return client
}

func createRawSignalClient(addr string) sigProto.SignalExchangeClient {
ctx := context.Background()
conn, err := grpc.DialContext(ctx, addr, grpc.WithInsecure(),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 3 * time.Second,
Timeout: 2 * time.Second,
}))
if err != nil {
Fail("failed creating raw signal client")
}

return sigProto.NewSignalExchangeClient(conn)
}

func startSignal() (*grpc.Server, net.Listener) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
s := grpc.NewServer()
sigProto.RegisterSignalExchangeServer(s, signal.NewServer())
go func() {
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()

return s, lis
}

// waitTimeout waits for the waitgroup for the specified max timeout.
// Returns true if waiting timed out.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
}

0 comments on commit 3c45da5

Please sign in to comment.