diff --git a/signal/client.go b/signal/client.go index 43f82b5c3f9..f0f3b95b266 100644 --- a/signal/client.go +++ b/signal/client.go @@ -29,7 +29,7 @@ type Client struct { ctx context.Context stream proto.SignalExchange_ConnectStreamClient //waiting group to notify once stream is connected - connWg sync.WaitGroup //todo use a channel instead?? + connWg *sync.WaitGroup //todo use a channel instead?? } // Close Closes underlying connections to the Signal Exchange @@ -55,11 +55,13 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key) (*Client, erro return nil, err } + var wg sync.WaitGroup return &Client{ realClient: proto.NewSignalExchangeClient(conn), ctx: ctx, signalConn: conn, key: key, + connWg: &wg, }, nil } @@ -107,8 +109,8 @@ func (c *Client) connect(key string, msgHandler func(msg *proto.Message) error) // add key fingerprint to the request header to be identified on the server side md := metadata.New(map[string]string{proto.HeaderId: key}) ctx := metadata.NewOutgoingContext(c.ctx, md) - ctx, cancel := context.WithCancel(ctx) - defer cancel() + //ctx, cancel := context.WithCancel(ctx) + //defer cancel() stream, err := c.realClient.ConnectStream(ctx) diff --git a/signal/signal.go b/signal/signal.go index a75c2050ef8..2c3101e2e7b 100644 --- a/signal/signal.go +++ b/signal/signal.go @@ -53,7 +53,6 @@ func (s *SignalExchangeServer) ConnectStream(stream proto.SignalExchange_Connect } log.Infof("peer [%s] has successfully connected", p.Id) - for { msg, err := stream.Recv() if err == io.EOF { diff --git a/signal/signal_test.go b/signal/signal_test.go index c9ffd5c0c6f..b43bba12947 100644 --- a/signal/signal_test.go +++ b/signal/signal_test.go @@ -66,7 +66,7 @@ var _ = Describe("Client", func() { Body: &sigProto.Body{Payload: "pong"}, }) if err != nil { - Fail("failed sending a message to {PeerA}") + Fail("failed sending a message to PeerA") } msgReceived.Done() return nil @@ -83,7 +83,9 @@ var _ = Describe("Client", func() { Fail("failed sending a message to PeerB") } - msgReceived.Wait() + if waitTimeout(&msgReceived, 3*time.Second) { + Fail("test timed out on waiting for peers to exchange messages") + } Expect(receivedOnA).To(BeEquivalentTo("pong")) Expect(receivedOnB).To(BeEquivalentTo("ping")) @@ -179,3 +181,19 @@ func startSignal() (*grpc.Server, net.Listener) { return s, lis } + +// waitTimeout waits for the waitgroup for the specified max timeout. +// Returns true if waiting timed out. +func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: + return false // completed normally + case <-time.After(timeout): + return true // timed out + } +}