diff --git a/manager/state/raft/membership/cluster.go b/manager/state/raft/membership/cluster.go index 84c9514066..1e43bd353f 100644 --- a/manager/state/raft/membership/cluster.go +++ b/manager/state/raft/membership/cluster.go @@ -25,6 +25,8 @@ var ( ErrConfigChangeInvalid = errors.New("membership: ConfChange type should be either AddNode, RemoveNode or UpdateNode") // ErrCannotUnmarshalConfig is thrown when a node cannot unmarshal a configuration change ErrCannotUnmarshalConfig = errors.New("membership: cannot unmarshal configuration change") + // ErrMemberRemoved is thrown when a node was removed from the cluster + ErrMemberRemoved = errors.New("raft: member was removed from the cluster") ) // deferredConn used to store removed members connection for some time. diff --git a/manager/state/raft/raft.go b/manager/state/raft/raft.go index 42cf6d40e2..c3ada921a8 100644 --- a/manager/state/raft/raft.go +++ b/manager/state/raft/raft.go @@ -56,8 +56,6 @@ var ( ErrRequestTooLarge = errors.New("raft: raft message is too large and can't be sent") // ErrCannotRemoveMember is thrown when we try to remove a member from the cluster but this would result in a loss of quorum ErrCannotRemoveMember = errors.New("raft: member cannot be removed, because removing it may result in loss of quorum") - // ErrMemberRemoved is thrown when a node was removed from the cluster - ErrMemberRemoved = errors.New("raft: member was removed from the cluster") // ErrNoClusterLeader is thrown when the cluster has no elected leader ErrNoClusterLeader = errors.New("raft: no elected cluster leader") // ErrMemberUnknown is sent in response to a message from an @@ -501,7 +499,7 @@ func (n *Node) Run(ctx context.Context) error { // If the node was removed from other members, // send back an error to the caller to start // the shutdown process. - return ErrMemberRemoved + return membership.ErrMemberRemoved case <-ctx.Done(): return nil } @@ -829,7 +827,7 @@ func (n *Node) ProcessRaftMessage(ctx context.Context, msg *api.ProcessRaftMessa // Don't process the message if this comes from // a node in the remove set if n.cluster.IsIDRemoved(msg.Message.From) { - return nil, ErrMemberRemoved + return nil, membership.ErrMemberRemoved } var sourceHost string @@ -1246,7 +1244,7 @@ func (n *Node) sendToMember(ctx context.Context, members map[uint64]*membership. _, err := api.NewRaftClient(conn.Conn).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m}) if err != nil { - if grpc.ErrorDesc(err) == ErrMemberRemoved.Error() { + if grpc.ErrorDesc(err) == membership.ErrMemberRemoved.Error() { n.removeRaftFunc() } if m.Type == raftpb.MsgSnap { diff --git a/manager/state/raft/transport/mock_raft_test.go b/manager/state/raft/transport/mock_raft_test.go new file mode 100644 index 0000000000..59917d94d9 --- /dev/null +++ b/manager/state/raft/transport/mock_raft_test.go @@ -0,0 +1,162 @@ +package transport + +import ( + "net" + "time" + + "golang.org/x/net/context" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/manager/health" + "github.com/docker/swarmkit/manager/state/raft/membership" + + "google.golang.org/grpc" +) + +type snapshotReport struct { + id uint64 + status raft.SnapshotStatus +} + +type mockRaft struct { + lis net.Listener + s *grpc.Server + tr *Transport + + cancel context.CancelFunc + + nodeRemovedSignal chan struct{} + + removed map[uint64]bool + + processedMessages chan *raftpb.Message + processedSnapshots chan snapshotReport + + reportedUnreachables chan uint64 +} + +func newMockRaft(ctx context.Context) (*mockRaft, error) { + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + return nil, err + } + ctx, cancel := context.WithCancel(ctx) + mr := &mockRaft{ + lis: l, + s: grpc.NewServer(), + cancel: cancel, + removed: make(map[uint64]bool), + nodeRemovedSignal: make(chan struct{}), + processedMessages: make(chan *raftpb.Message, 4096), + processedSnapshots: make(chan snapshotReport, 4096), + reportedUnreachables: make(chan uint64, 4096), + } + cfg := &Config{ + SendTimeout: 2 * time.Second, + Raft: mr, + } + tr := New(ctx, cfg) + mr.tr = tr + hs := health.NewHealthServer() + hs.SetServingStatus("Raft", api.HealthCheckResponse_SERVING) + api.RegisterRaftServer(mr.s, mr) + api.RegisterHealthServer(mr.s, hs) + go mr.s.Serve(l) + return mr, nil +} + +func (r *mockRaft) Addr() string { + return r.lis.Addr().String() +} + +func (r *mockRaft) Stop() { + r.cancel() + r.s.Stop() +} + +func (r *mockRaft) RemovePeer(id uint64) error { + r.removed[id] = true + return r.tr.RemovePeer(id) +} + +func (r *mockRaft) ProcessRaftMessage(ctx context.Context, req *api.ProcessRaftMessageRequest) (*api.ProcessRaftMessageResponse, error) { + if r.removed[req.Message.From] { + return nil, membership.ErrMemberRemoved + } + r.processedMessages <- req.Message + return &api.ProcessRaftMessageResponse{}, nil +} + +func (r *mockRaft) ResolveAddress(ctx context.Context, req *api.ResolveAddressRequest) (*api.ResolveAddressResponse, error) { + addr, err := r.tr.GetPeerAddr(req.RaftID) + if err != nil { + return nil, err + } + return &api.ResolveAddressResponse{ + Addr: addr, + }, nil +} + +func (r *mockRaft) ReportUnreachable(id uint64) { + r.reportedUnreachables <- id +} + +func (r *mockRaft) IsIDRemoved(id uint64) bool { + return r.removed[id] +} + +func (r *mockRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) { + r.processedSnapshots <- snapshotReport{ + id: id, + status: status, + } +} + +func (r *mockRaft) NodeRemoved() { + close(r.nodeRemovedSignal) +} + +type mockCluster struct { + rafts map[uint64]*mockRaft + ctx context.Context + cancel context.CancelFunc +} + +func newCluster(ctx context.Context) *mockCluster { + ctx, cancel := context.WithCancel(ctx) + return &mockCluster{ + rafts: make(map[uint64]*mockRaft), + ctx: ctx, + cancel: cancel, + } +} + +func (c *mockCluster) Stop() { + c.cancel() + for _, r := range c.rafts { + r.s.Stop() + } +} + +func (c *mockCluster) Add(id uint64) error { + mr, err := newMockRaft(c.ctx) + if err != nil { + return err + } + for otherID, otherRaft := range c.rafts { + if err := mr.tr.AddPeer(c.ctx, otherID, otherRaft.Addr()); err != nil { + return err + } + if err := otherRaft.tr.AddPeer(c.ctx, id, mr.Addr()); err != nil { + return err + } + } + c.rafts[id] = mr + return nil +} + +func (c *mockCluster) Get(id uint64) *mockRaft { + return c.rafts[id] +} diff --git a/manager/state/raft/transport/peer.go b/manager/state/raft/transport/peer.go new file mode 100644 index 0000000000..6719640095 --- /dev/null +++ b/manager/state/raft/transport/peer.go @@ -0,0 +1,173 @@ +package transport + +import ( + "sync" + "time" + + "golang.org/x/net/context" + + "google.golang.org/grpc" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/log" + "github.com/docker/swarmkit/manager/state/raft/membership" + "github.com/pkg/errors" +) + +type peer struct { + id uint64 + + tr *Transport + + msgc chan raftpb.Message + + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + mu sync.Mutex + cc *grpc.ClientConn + addr string + + active bool + becameActive time.Time +} + +func newPeer(ctx context.Context, id uint64, addr string, tr *Transport) (*peer, error) { + cc, err := tr.dial(ctx, addr) + if err != nil { + return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr) + } + ctx, cancel := context.WithCancel(tr.ctx) + p := &peer{ + id: id, + addr: addr, + cc: cc, + tr: tr, + ctx: ctx, + cancel: cancel, + msgc: make(chan raftpb.Message, 4096), + done: make(chan struct{}), + active: true, + becameActive: time.Now(), + } + go p.run(ctx) + return p, nil +} + +func (p *peer) send(ctx context.Context, m raftpb.Message) (err error) { + defer func() { + if err != nil { + p.mu.Lock() + p.active = false + p.mu.Unlock() + } + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-p.ctx.Done(): + return ctx.Err() + default: + } + select { + case p.msgc <- m: + case <-ctx.Done(): + return ctx.Err() + case <-p.ctx.Done(): + return ctx.Err() + default: + p.tr.config.Raft.ReportUnreachable(m.To) + return errors.Errorf("peer is unreachable") + } + return nil +} + +func (p *peer) update(ctx context.Context, addr string) error { + p.mu.Lock() + defer p.mu.Unlock() + if addr == p.addr { + return nil + } + cc, err := p.tr.dial(ctx, addr) + if err != nil { + return err + } + p.cc.Close() + p.cc = cc + p.addr = addr + return nil +} + +func (p *peer) conn() *grpc.ClientConn { + p.mu.Lock() + defer p.mu.Unlock() + return p.cc +} + +func (p *peer) address() string { + p.mu.Lock() + defer p.mu.Unlock() + return p.addr +} + +func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) { + resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id}) + if err != nil { + return "", errors.Wrap(err, "failed to resolve address") + } + return resp.Addr, nil +} + +func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error { + if _, err := api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m}); err != nil { + p.mu.Lock() + p.active = false + p.mu.Unlock() + if m.Type == raftpb.MsgSnap { + p.tr.config.Raft.ReportSnapshot(m.To, raft.SnapshotFailure) + } + p.tr.config.Raft.ReportUnreachable(m.To) + if grpc.ErrorDesc(err) == membership.ErrMemberRemoved.Error() { + p.tr.config.Raft.NodeRemoved() + } + return err + } + if m.Type == raftpb.MsgSnap { + p.tr.config.Raft.ReportSnapshot(m.To, raft.SnapshotFinish) + } + p.mu.Lock() + if !p.active { + p.active = true + p.becameActive = time.Now() + } + p.mu.Unlock() + return nil +} + +func (p *peer) run(ctx context.Context) { + defer func() { + close(p.done) + p.cc.Close() + }() + for { + select { + case <-ctx.Done(): + return + default: + } + select { + case m := <-p.msgc: + ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout) + err := p.sendProcessMessage(ctx, m) + cancel() + if err != nil { + log.G(ctx).WithError(err).Errorf("failed to send message %s to peer %x", m.Type, m.To) + } + case <-ctx.Done(): + return + } + } +} diff --git a/manager/state/raft/transport/transport.go b/manager/state/raft/transport/transport.go new file mode 100644 index 0000000000..372326faa1 --- /dev/null +++ b/manager/state/raft/transport/transport.go @@ -0,0 +1,305 @@ +package transport + +import ( + "sync" + "time" + + "golang.org/x/net/context" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/docker/swarmkit/api" + "github.com/docker/swarmkit/log" + "github.com/pkg/errors" +) + +// ErrIsNotFound indicates that peer was never added to transport. +var ErrIsNotFound = errors.New("peer not found") + +// Raft is interface which represents Raft API for transport package. +type Raft interface { + ReportUnreachable(id uint64) + ReportSnapshot(id uint64, status raft.SnapshotStatus) + IsIDRemoved(id uint64) bool + + NodeRemoved() +} + +// Config for Transport +type Config struct { + SendTimeout time.Duration + Credentials credentials.TransportCredentials + + Raft Raft +} + +// Transport is structure which manages remote raft peers and sends messages +// to them. +type Transport struct { + config *Config + + unknownc chan raftpb.Message + + mu sync.Mutex + peers map[uint64]*peer + stopped bool + + ctx context.Context + done chan struct{} +} + +// New returns new Transport with specified Context and Config. +// Transport will shutdown on context cancel. +func New(ctx context.Context, cfg *Config) *Transport { + t := &Transport{ + peers: make(map[uint64]*peer), + config: cfg, + unknownc: make(chan raftpb.Message), + done: make(chan struct{}), + ctx: ctx, + } + go t.run(ctx) + go t.unknownSender(ctx) + return t +} + +func (t *Transport) run(ctx context.Context) { + <-ctx.Done() + log.G(ctx).Debug("stop transport") + t.mu.Lock() + defer t.mu.Unlock() + t.stopped = true + for _, p := range t.peers { + p.cancel() + <-p.done + } + close(t.done) +} + +func (t *Transport) sendUnknownMessage(ctx context.Context, m raftpb.Message) error { + ctx, cancel := context.WithTimeout(ctx, 2*t.config.SendTimeout) + defer cancel() + p, err := t.resolvePeer(ctx, m.To) + if err != nil { + return errors.Wrapf(err, "failed to resolve peer") + } + defer p.cancel() + if err := p.sendProcessMessage(ctx, m); err != nil { + return errors.Wrapf(err, "failed to send message") + } + return nil +} + +// unknownSender sends messages to unknown peers. It creates new peer for each +// message and discards it after send. +func (t *Transport) unknownSender(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + default: + } + select { + case m := <-t.unknownc: + if err := t.sendUnknownMessage(ctx, m); err != nil { + log.G(ctx).WithError(err).Warnf("ignored message %s to unknown peer %x", m.Type, m.To) + } + case <-ctx.Done(): + return + } + } +} + +// Done returns channel which will be closed when transport is entirely shutdown. +func (t *Transport) Done() chan struct{} { + return t.done +} + +// AddPeer adds new peer with id and address addr to Transport. +// If there is already peer with such id in Transport it will return error if +// addres is differ(UpdatePeer should be used) or nil otherwise. +func (t *Transport) AddPeer(ctx context.Context, id uint64, addr string) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped { + return errors.New("transport stopped") + } + if ep, ok := t.peers[id]; ok { + if ep.address() == addr { + return nil + } + return errors.Errorf("peer %x already added with addr %s", id, ep.addr) + } + p, err := newPeer(ctx, id, addr, t) + if err != nil { + return errors.Wrapf(err, "failed to create peer %x with addr %s", id, addr) + } + t.peers[id] = p + return nil +} + +// RemovePeer removes peer from Transport and wait for it to stop. +func (t *Transport) RemovePeer(id uint64) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped { + return errors.New("transport stopped") + } + p, ok := t.peers[id] + if !ok { + return ErrIsNotFound + } + p.cancel() + <-p.done + delete(t.peers, id) + return nil +} + +// UpdatePeer updates peer with new address. +func (t *Transport) UpdatePeer(ctx context.Context, id uint64, addr string) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped { + return errors.New("transport stopped") + } + p, ok := t.peers[id] + if !ok { + return ErrIsNotFound + } + return p.update(ctx, addr) +} + +// GetPeerAddr returns address of peer with id. +func (t *Transport) GetPeerAddr(id uint64) (string, error) { + t.mu.Lock() + defer t.mu.Unlock() + p, ok := t.peers[id] + if !ok { + return "", ErrIsNotFound + } + return p.address(), nil +} + +func (t *Transport) resolvePeer(ctx context.Context, id uint64) (*peer, error) { + longestActive, err := t.longestActive() + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(ctx, 2*t.config.SendTimeout) + defer cancel() + addr, err := longestActive.resolveAddr(ctx, id) + if err != nil { + return nil, err + } + return newPeer(ctx, id, addr, t) +} + +// Send sends raft message to remote peers. +func (t *Transport) Send(ctx context.Context, m raftpb.Message) error { + t.mu.Lock() + defer t.mu.Unlock() + if t.stopped { + return errors.New("transport stopped") + } + if t.config.Raft.IsIDRemoved(m.To) { + return errors.Errorf("refusing to send message %s to removed member %x", m.Type, m.To) + } + p, ok := t.peers[m.To] + if !ok { + log.G(ctx).Warningf("sending message %s to an unrecognized member ID %x", m.Type, m.To) + select { + // we need to process messages to unknown peers in separate goroutine + // to not block sender + case t.unknownc <- m: + case <-ctx.Done(): + return ctx.Err() + case <-t.ctx.Done(): + return ctx.Err() + default: + return errors.New("unknown messages queue is full") + } + return nil + } + if err := p.send(ctx, m); err != nil { + return errors.Wrapf(err, "failed to send message %x to %x", m.Type, m.To) + } + return nil +} + +func (t *Transport) longestActive() (*peer, error) { + var longest *peer + var longestTime time.Time + for _, p := range t.peers { + p.mu.Lock() + active, becameActive := p.active, p.becameActive + p.mu.Unlock() + if !active { + continue + } + if longest == nil { + longest = p + continue + } + if becameActive.Before(longestTime) { + longest = p + longestTime = becameActive + } + } + if longest == nil { + return nil, errors.New("failed to find longest active peer") + } + return longest, nil +} + +func (t *Transport) dial(ctx context.Context, addr string) (*grpc.ClientConn, error) { + grpcOptions := []grpc.DialOption{ + grpc.WithBackoffMaxDelay(2 * time.Second), + grpc.WithBlock(), + } + if t.config.Credentials != nil { + grpcOptions = append(grpcOptions, grpc.WithTransportCredentials(t.config.Credentials)) + } else { + grpcOptions = append(grpcOptions, grpc.WithInsecure()) + } + + if t.config.SendTimeout > 0 { + grpcOptions = append(grpcOptions, grpc.WithTimeout(t.config.SendTimeout)) + } + + cc, err := grpc.Dial(addr, grpcOptions...) + if err != nil { + return nil, err + } + + ctx, cancel := t.withContext(ctx) + defer cancel() + ctx, cancel = context.WithTimeout(ctx, t.config.SendTimeout) + defer cancel() + resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"}) + if err != nil { + cc.Close() + return nil, errors.Wrap(err, "failed to check health") + } + if resp.Status != api.HealthCheckResponse_SERVING { + cc.Close() + return nil, errors.Errorf("health check returned status %s", resp.Status) + } + + return cc, nil +} + +func (t *Transport) withContext(ctx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + + go func() { + select { + case <-ctx.Done(): + case <-t.ctx.Done(): + cancel() + } + }() + return ctx, cancel +} diff --git a/manager/state/raft/transport/transport_test.go b/manager/state/raft/transport/transport_test.go new file mode 100644 index 0000000000..1ed96081de --- /dev/null +++ b/manager/state/raft/transport/transport_test.go @@ -0,0 +1,259 @@ +package transport + +import ( + "fmt" + "net" + "testing" + "time" + + "golang.org/x/net/context" + + "github.com/coreos/etcd/raft" + "github.com/coreos/etcd/raft/raftpb" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func sendMessages(ctx context.Context, c *mockCluster, from uint64, to []uint64, msgType raftpb.MessageType) error { + var firstErr error + for _, id := range to { + err := c.Get(from).tr.Send(ctx, raftpb.Message{ + Type: msgType, + From: from, + To: id, + }) + if firstErr == nil { + firstErr = err + } + } + return firstErr +} + +func testSend(ctx context.Context, c *mockCluster, from uint64, to []uint64, msgType raftpb.MessageType) func(*testing.T) { + return func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + require.NoError(t, sendMessages(ctx, c, from, to, msgType)) + + for _, id := range to { + select { + case msg := <-c.Get(id).processedMessages: + assert.Equal(t, msg.To, id) + assert.Equal(t, msg.From, from) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + } + + if msgType == raftpb.MsgSnap { + var snaps []snapshotReport + for i := 0; i < len(to); i++ { + select { + case snap := <-c.Get(from).processedSnapshots: + snaps = append(snaps, snap) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } + } + loop: + for _, id := range to { + for _, s := range snaps { + if s.id == id { + assert.Equal(t, s.status, raft.SnapshotFinish) + continue loop + } + } + t.Fatalf("shapshot ot %d is not reported", id) + } + } + } +} + +func TestSend(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + t.Run("Send Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) + t.Run("Send_Snapshot_Message", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgSnap)) +} + +func TestSendRemoved(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + require.NoError(t, c.Get(1).RemovePeer(2)) + + err := sendMessages(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup) + require.Error(t, err) + require.Contains(t, err.Error(), "to removed member") +} + +func TestSendSnapshotFailure(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + + // stop peer server to emulate error + c.Get(2).s.Stop() + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + require.NoError(t, sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgSnap)) + + select { + case snap := <-c.Get(1).processedSnapshots: + assert.Equal(t, snap.id, uint64(2)) + assert.Equal(t, snap.status, raft.SnapshotFailure) + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } + + select { + case id := <-c.Get(1).reportedUnreachables: + assert.Equal(t, id, uint64(2)) + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } +} + +func TestSendUnknown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + // remove peeer from 1 transport to make it "unknown" to it + oldPeer := c.Get(1).tr.peers[2] + delete(c.Get(1).tr.peers, 2) + oldPeer.cancel() + <-oldPeer.done + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + require.NoError(t, sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgHup)) + + select { + case msg := <-c.Get(2).processedMessages: + assert.Equal(t, msg.To, uint64(2)) + assert.Equal(t, msg.From, uint64(1)) + case <-msgCtx.Done(): + t.Fatal(msgCtx.Err()) + } +} + +func TestUpdatePeer(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + require.NoError(t, c.Add(3)) + + t.Run("Send Message Before Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) + + nr, err := newMockRaft(ctx) + require.NoError(t, err) + + c.Get(3).Stop() + c.rafts[3] = nr + + require.NoError(t, c.Get(1).tr.UpdatePeer(ctx, 3, nr.Addr())) + require.NoError(t, c.Get(1).tr.UpdatePeer(ctx, 3, nr.Addr())) + + t.Run("Send Message After Address Update", testSend(ctx, c, 1, []uint64{2, 3}, raftpb.MsgHup)) +} + +func TestSendUnreachable(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + + // set channel to nil to emulate full queue + c.Get(1).tr.peers[2].msgc = nil + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + err := sendMessages(msgCtx, c, 1, []uint64{2}, raftpb.MsgSnap) + require.Error(t, err) + require.Contains(t, err.Error(), "peer is unreachable") + select { + case id := <-c.Get(1).reportedUnreachables: + assert.Equal(t, id, uint64(2)) + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } +} + +func TestSendNodeRemoved(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + require.NoError(t, c.Add(2)) + + require.NoError(t, c.Get(1).RemovePeer(2)) + + msgCtx, msgCancel := context.WithTimeout(ctx, 4*time.Second) + defer msgCancel() + + require.NoError(t, sendMessages(msgCtx, c, 2, []uint64{1}, raftpb.MsgSnap)) + select { + case <-c.Get(2).nodeRemovedSignal: + case <-msgCtx.Done(): + t.Fatal(ctx.Err()) + } +} + +func TestAddPeerNotRaft(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + c := newCluster(ctx) + defer func() { + cancel() + c.Stop() + }() + require.NoError(t, c.Add(1)) + l, err := net.Listen("tcp", "0.0.0.0:0") + require.NoError(t, err) + + err = c.Get(1).tr.AddPeer(ctx, 2, l.Addr().String()) + fmt.Printf("%v\n", err) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to check health") +} diff --git a/vendor/google.golang.org/grpc/transport/transport.go b/vendor/google.golang.org/grpc/transport/transport.go index 3d6b6a6d51..8f2921fdc2 100644 --- a/vendor/google.golang.org/grpc/transport/transport.go +++ b/vendor/google.golang.org/grpc/transport/transport.go @@ -539,6 +539,8 @@ func ContextErr(err error) StreamError { case context.Canceled: return streamErrorf(codes.Canceled, "%v", err) } + fmt.Printf("%T %v\n", err, err) + fmt.Printf("%T %v\n", context.Canceled, context.Canceled) panic(fmt.Sprintf("Unexpected error from context packet: %v", err)) }