diff --git a/gossip/comm/comm_impl.go b/gossip/comm/comm_impl.go new file mode 100644 index 00000000000..dce254694aa --- /dev/null +++ b/gossip/comm/comm_impl.go @@ -0,0 +1,539 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "bytes" + "fmt" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" + + "crypto/tls" + "os" + + "github.com/hyperledger/fabric/gossip/proto" + "github.com/hyperledger/fabric/gossip/util" + "github.com/op/go-logging" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +const ( + defDialTimeout = time.Second * time.Duration(3) + defConnTimeout = time.Second * time.Duration(2) + defRecvBuffSize = 100 +) + +var dialTimeout = defDialTimeout + +func init() { + rand.Seed(42) +} + +// SetDialTimeout sets the dial timeout +func SetDialTimeout(timeout time.Duration) { + dialTimeout = timeout +} + +func (c *commImpl) SetDialOpts(opts ...grpc.DialOption) { + c.opts = opts +} + +// NewCommInstanceWithServer creates a comm instance that creates an underlying gRPC server +func NewCommInstanceWithServer(port int, sec SecurityProvider, pkID PKIidType, dialOpts ...grpc.DialOption) (Comm, error) { + var ll net.Listener + var s *grpc.Server + var secOpt grpc.DialOption + + if len(dialOpts) == 0 { + dialOpts = []grpc.DialOption{grpc.WithTimeout(dialTimeout)} + } + + if port > 0 { + s, ll, secOpt = createGRPCLayer(port) + dialOpts = append(dialOpts, secOpt) + } + + commInst := &commImpl{ + logger: util.GetLogger(util.LOGGING_COMM_MODULE, fmt.Sprintf("%d", port)), + PKIID: pkID, + opts: dialOpts, + sec: sec, + port: port, + lsnr: ll, + gSrv: s, + msgPublisher: NewChannelDemultiplexer(), + lock: &sync.RWMutex{}, + deadEndpoints: make(chan PKIidType, 100), + stopping: int32(0), + exitChan: make(chan struct{}, 1), + subscriptions: make([]chan ReceivedMessage, 0), + blackListedPKIIDs: make([]PKIidType, 0), + } + commInst.connStore = newConnStore(commInst, pkID, commInst.logger) + + if port > 0 { + go func() { + commInst.stopWG.Add(1) + defer commInst.stopWG.Done() + s.Serve(ll) + }() + } + + proto.RegisterGossipServer(s, commInst) + + commInst.logger.SetLevel(logging.WARNING) + + time.Sleep(time.Duration(200) * time.Millisecond) + + return commInst, nil +} + +// NewCommInstance creates a new comm instance that binds itself to the given gRPC server +func NewCommInstance(s *grpc.Server, sec SecurityProvider, PKIID PKIidType, dialOpts ...grpc.DialOption) (Comm, error) { + commInst, err := NewCommInstanceWithServer(-1, sec, PKIID) + if err != nil { + return nil, err + } + proto.RegisterGossipServer(s, commInst.(*commImpl)) + return commInst, nil +} + +type commImpl struct { + logger *util.Logger + sec SecurityProvider + opts []grpc.DialOption + connStore *connectionStore + PKIID []byte + port int + deadEndpoints chan PKIidType + msgPublisher *ChannelDeMultiplexer + lock *sync.RWMutex + lsnr net.Listener + gSrv *grpc.Server + exitChan chan struct{} + stopping int32 + stopWG sync.WaitGroup + subscriptions []chan ReceivedMessage + blackListedPKIIDs []PKIidType +} + +func (c *commImpl) createConnection(endpoint string, expectedPKIID PKIidType) (*connection, error) { + c.logger.Debug("Entering", endpoint, expectedPKIID) + defer c.logger.Debug("Exiting") + if c.isStopping() { + return nil, fmt.Errorf("Stopping") + } + cc, err := grpc.Dial(endpoint, append(c.opts, grpc.WithBlock())...) + if err != nil { + if cc != nil { + cc.Close() + } + return nil, err + } + + cl := proto.NewGossipClient(cc) + + if _, err := cl.Ping(context.Background(), &proto.Empty{}); err != nil { + cc.Close() + return nil, err + } + + if stream, err := cl.GossipStream(context.Background()); err == nil { + pkiID, err := c.authenticateRemotePeer(stream) + if expectedPKIID != nil && !bytes.Equal(pkiID, expectedPKIID) { + // PKIID is nil when we don't know the remote PKI id's + c.logger.Warning("Remote endpoint claims to be a different peer, expected", expectedPKIID, "but got", pkiID) + return nil, fmt.Errorf("Authentication failure") + } + if err == nil { + conn := newConnection(cl, cc, stream, nil) + conn.pkiID = pkiID + conn.logger = c.logger + + h := func(m *proto.GossipMessage) { + c.logger.Debug("Got message:", m) + c.msgPublisher.DeMultiplex(&ReceivedMessageImpl{ + conn: conn, + lock: conn, + GossipMessage: m, + }) + } + conn.handler = h + return conn, nil + } + return nil, fmt.Errorf("Authentication failure") + } + cc.Close() + return nil, err +} + +func (c *commImpl) Send(msg *proto.GossipMessage, peers ...*RemotePeer) { + if c.isStopping() { + return + } + + c.logger.Info("Entering, sending", msg, "to ", len(peers), "peers") + + for _, peer := range peers { + // TODO: create outgoing buffers and flow control per connection + go func(peer *RemotePeer, msg *proto.GossipMessage) { + c.sendToEndpoint(peer, msg) + }(peer, msg) + } +} + +func (c *commImpl) BlackListPKIid(PKIID PKIidType) { + c.logger.Info("Entering", PKIID) + defer c.logger.Info("Exiting") + c.lock.Lock() + defer c.lock.Unlock() + c.connStore.closeByPKIid(PKIID) + c.blackListedPKIIDs = append(c.blackListedPKIIDs, PKIID) +} + +func (c *commImpl) isPKIblackListed(p PKIidType) bool { + c.lock.RLock() + defer c.lock.RUnlock() + for _, pki := range c.blackListedPKIIDs { + if bytes.Equal(pki, p) { + c.logger.Debug(p, ":", true) + return true + } + } + c.logger.Debug(p, ":", false) + return false +} + +func (c *commImpl) sendToEndpoint(peer *RemotePeer, msg *proto.GossipMessage) error { + if c.isStopping() { + return nil + } + c.logger.Debug("Entering, Sending to", peer.Endpoint, ", msg:", msg) + defer c.logger.Debug("Exiting") + var err error + + conn, err := c.connStore.getConnection(peer) + if err == nil { + t1 := time.Now() + err = conn.send(msg) + if err != nil { + c.logger.Warning(peer, "isn't responsive:", err) + c.disconnect(peer.PKIID) + return err + } + c.logger.Debug("Send took", time.Since(t1)) + return nil + } + c.logger.Warning("Failed obtaining connection for", peer, "reason:", err) + c.disconnect(peer.PKIID) + return err +} + +func (c *commImpl) isStopping() bool { + return atomic.LoadInt32(&c.stopping) == int32(1) +} + +func (c *commImpl) Probe(endpoint string, pkiID PKIidType) error { + if c.isStopping() { + return fmt.Errorf("Stopping!") + } + c.logger.Debug("Entering, endpoint:", endpoint, "PKIID:", pkiID) + var err error + + opts := c.opts + if opts == nil { + opts = []grpc.DialOption{grpc.WithInsecure(), grpc.WithTimeout(dialTimeout)} + } + cc, err := grpc.Dial(endpoint, append(opts, grpc.WithBlock())...) + if err != nil { + c.logger.Debug("Returning", err) + return err + } + defer cc.Close() + cl := proto.NewGossipClient(cc) + _, err = cl.Ping(context.Background(), &proto.Empty{}) + c.logger.Debug("Returning", err) + return err +} + +func (c *commImpl) Accept(acceptor util.MessageAcceptor) <-chan ReceivedMessage { + genericChan := c.msgPublisher.AddChannel(acceptor) + specificChan := make(chan ReceivedMessage, 10) + + if c.isStopping() { + c.logger.Warning("Accept() called but comm module is stopping, returning empty channel") + return specificChan + } + + c.lock.Lock() + c.subscriptions = append(c.subscriptions, specificChan) + c.lock.Unlock() + + go func() { + defer c.logger.Debug("Exiting Accept() loop") + defer func() { + c.logger.Warning("Recovered") + recover() + }() + + c.stopWG.Add(1) + defer c.stopWG.Done() + + for { + select { + case msg := <-genericChan: + specificChan <- msg.(*ReceivedMessageImpl) + break + case s := <-c.exitChan: + c.exitChan <- s + return + } + } + }() + return specificChan +} + +func (c *commImpl) PresumedDead() <-chan PKIidType { + return c.deadEndpoints +} + +func (c *commImpl) CloseConn(peer *RemotePeer) { + c.logger.Info("Closing connection for", peer) + c.connStore.closeConn(peer) +} + +func (c *commImpl) emptySubscriptions() { + c.lock.Lock() + defer c.lock.Unlock() + for _, ch := range c.subscriptions { + close(ch) + } +} + +func (c *commImpl) Stop() { + if c.isStopping() { + return + } + atomic.StoreInt32(&c.stopping, int32(1)) + c.logger.Info("Stopping") + defer c.logger.Info("Stopped") + if c.gSrv != nil { + c.gSrv.Stop() + } + if c.lsnr != nil { + c.lsnr.Close() + } + c.connStore.shutdown() + c.logger.Debug("Shut down connection store, connection count:", c.connStore.connNum()) + c.exitChan <- struct{}{} + c.msgPublisher.Close() + c.logger.Debug("Shut down publisher") + c.emptySubscriptions() + c.logger.Debug("Closed subscriptions, waiting for goroutines to stop...") + c.stopWG.Wait() +} + +func (c *commImpl) GetPKIid() PKIidType { + return c.PKIID +} + +func extractRemoteAddress(stream stream) string { + var remoteAddress string + p, ok := peer.FromContext(stream.Context()) + if ok { + if address := p.Addr; address != nil { + remoteAddress = address.String() + } + } + return remoteAddress +} + +func (c *commImpl) authenticateRemotePeer(stream stream) (PKIidType, error) { + ctx := stream.Context() + remoteAddress := extractRemoteAddress(stream) + tlsUnique := ExtractTLSUnique(ctx) + var sig []byte + var err error + if tlsUnique != nil && c.sec.IsEnabled() { + sig, err = c.sec.Sign(tlsUnique) + if err != nil { + c.logger.Error("Failed signing TLS-Unique:", err) + return nil, err + } + } + + cMsg := createConnectionMsg(c.PKIID, sig) + stream.Send(cMsg) + m := readWithTimeout(stream, defConnTimeout) + if m == nil { + c.logger.Warning("Timed out waiting for connection message from", remoteAddress) + return nil, fmt.Errorf("Timed out") + } + connMsg := m.GetConn() + if connMsg == nil { + c.logger.Warning("Expected connection message but got", connMsg) + return nil, fmt.Errorf("Wrong type") + } + if c.isPKIblackListed(connMsg.PkiID) { + c.logger.Warning("Connection attempt from", remoteAddress, "but it is black-listed") + return nil, fmt.Errorf("Black-listed") + } + + if tlsUnique != nil && c.sec.IsEnabled() { + err = c.sec.Verify(connMsg.PkiID, connMsg.Sig, tlsUnique) + if err != nil { + c.logger.Error("Failed verifying signature from", remoteAddress, ":", err) + return nil, err + } + } + + if connMsg.PkiID == nil { + return nil, fmt.Errorf("%s didn't send a pkiID", "Didn't send a pkiID") + } + + c.logger.Debug("Authenticated", remoteAddress) + return connMsg.PkiID, nil + +} + +func (c *commImpl) GossipStream(stream proto.Gossip_GossipStreamServer) error { + if c.isStopping() { + return fmt.Errorf("Shutting down") + } + PKIID, err := c.authenticateRemotePeer(stream) + if err != nil { + c.logger.Error("Authentication failed") + return err + } + c.logger.Info("Servicing", extractRemoteAddress(stream)) + + conn := c.connStore.onConnected(stream, PKIID) + + // if connStore denied the connection, it means we already have a connection to that peer + // so close this stream + if conn == nil { + return nil + } + + h := func(m *proto.GossipMessage) { + c.msgPublisher.DeMultiplex(&ReceivedMessageImpl{ + conn: conn, + lock: conn, + GossipMessage: m, + }) + } + + conn.handler = h + + defer func() { + c.logger.Info("Client", extractRemoteAddress(stream), " disconnected") + c.connStore.closeByPKIid(PKIID) + }() + + return conn.serviceInput() +} + +func (c *commImpl) Ping(context.Context, *proto.Empty) (*proto.Empty, error) { + return &proto.Empty{}, nil +} + +func (c *commImpl) disconnect(pkiID PKIidType) { + c.deadEndpoints <- pkiID + c.connStore.closeByPKIid(pkiID) +} + +func readWithTimeout(stream interface{}, timeout time.Duration) *proto.GossipMessage { + incChan := make(chan *proto.GossipMessage, 1) + go func() { + if srvStr, isServerStr := stream.(proto.Gossip_GossipStreamServer); isServerStr { + if m, err := srvStr.Recv(); err == nil { + incChan <- m + } + } + if clStr, isClientStr := stream.(proto.Gossip_GossipStreamClient); isClientStr { + if m, err := clStr.Recv(); err == nil { + incChan <- m + } + } + }() + select { + case <-time.NewTicker(timeout).C: + return nil + case m := <-incChan: + return m + } +} + +func createConnectionMsg(pkiID PKIidType, sig []byte) *proto.GossipMessage { + return &proto.GossipMessage{ + Nonce: 0, + Content: &proto.GossipMessage_Conn{ + Conn: &proto.ConnEstablish{ + PkiID: pkiID, + Sig: sig, + }, + }, + } +} + +type stream interface { + Send(*proto.GossipMessage) error + Recv() (*proto.GossipMessage, error) + grpc.Stream +} + +func createGRPCLayer(port int) (*grpc.Server, net.Listener, grpc.DialOption) { + var s *grpc.Server + var ll net.Listener + var err error + var serverOpts []grpc.ServerOption + var dialOpts grpc.DialOption + + keyFileName := fmt.Sprintf("key.%d.pem", time.Now().UnixNano()) + certFileName := fmt.Sprintf("cert.%d.pem", time.Now().UnixNano()) + + defer os.Remove(keyFileName) + defer os.Remove(certFileName) + + err = generateCertificates(keyFileName, certFileName) + if err == nil { + var creds credentials.TransportCredentials + creds, err = credentials.NewServerTLSFromFile(certFileName, keyFileName) + serverOpts = append(serverOpts, grpc.Creds(creds)) + ta := credentials.NewTLS(&tls.Config{ + InsecureSkipVerify: true, + }) + dialOpts = grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}) + } else { + dialOpts = grpc.WithInsecure() + } + + listenAddress := fmt.Sprintf("%s:%d", "", port) + ll, err = net.Listen("tcp", listenAddress) + if err != nil { + panic(err) + } + + s = grpc.NewServer(serverOpts...) + return s, ll, dialOpts +} diff --git a/gossip/comm/comm_test.go b/gossip/comm/comm_test.go new file mode 100644 index 00000000000..99c7014c0f6 --- /dev/null +++ b/gossip/comm/comm_test.go @@ -0,0 +1,460 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "bytes" + "fmt" + "math/rand" + "sync" + "testing" + "time" + + "crypto/tls" + + "github.com/hyperledger/fabric/gossip/proto" + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +func init() { + rand.Seed(42) + SetDialTimeout(time.Duration(300) * time.Millisecond) +} + +func acceptAll(msg interface{}) bool { + return true +} + +var naiveSec = &naiveSecProvider{} + +type naiveSecProvider struct { +} + +func (*naiveSecProvider) IsEnabled() bool { + return true +} + +func (*naiveSecProvider) Sign(msg []byte) ([]byte, error) { + return msg, nil +} + +func (*naiveSecProvider) Verify(vkID, signature, message []byte) error { + if bytes.Equal(signature, message) { + return nil + } + return fmt.Errorf("Failed verifying") +} + +func newCommInstance(port int, sec SecurityProvider) (Comm, error) { + endpoint := fmt.Sprintf("localhost:%d", port) + inst, err := NewCommInstanceWithServer(port, sec, []byte(endpoint)) + return inst, err +} + +func TestHandshake(t *testing.T) { + comm1, _ := newCommInstance(9611, naiveSec) + defer comm1.Stop() + + ta := credentials.NewTLS(&tls.Config{ + InsecureSkipVerify: true, + }) + conn, err := grpc.Dial("localhost:9611", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second)) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + cl := proto.NewGossipClient(conn) + stream, err := cl.GossipStream(context.Background()) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + + // happy path + clientTLSUnique := ExtractTLSUnique(stream.Context()) + sig, err := naiveSec.Sign(clientTLSUnique) + assert.NoError(t, err, "%v", err) + msg := createConnectionMsg(PKIidType("localhost:9610"), sig) + stream.Send(msg) + msg, err = stream.Recv() + assert.NoError(t, err, "%v", err) + assert.Equal(t, clientTLSUnique, msg.GetConn().Sig) + assert.Equal(t, []byte("localhost:9611"), msg.GetConn().PkiID) + time.Sleep(time.Second) + msg2Send := createGossipMsg() + nonce := uint64(rand.Int()) + msg2Send.Nonce = nonce + rcvChan := make(chan *proto.GossipMessage, 1) + go func() { + m := <-comm1.Accept(acceptAll) + rcvChan <- m.GetGossipMessage() + }() + stream.Send(msg2Send) + time.Sleep(time.Second) + assert.Equal(t, 1, len(rcvChan)) + receivedMsg := <-rcvChan + assert.Equal(t, nonce, receivedMsg.Nonce) + + // negative path, nothing should be read from the channel because the signature is wrong + rcvChan = make(chan *proto.GossipMessage, 1) + go func() { + m := <-comm1.Accept(acceptAll) + if m == nil { + return + } + rcvChan <- m.GetGossipMessage() + }() + conn, err = grpc.Dial("localhost:9611", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second)) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + cl = proto.NewGossipClient(conn) + stream, err = cl.GossipStream(context.Background()) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + clientTLSUnique = ExtractTLSUnique(stream.Context()) + sig, err = naiveSec.Sign(clientTLSUnique) + assert.NoError(t, err, "%v", err) + // ruin the signature + if sig[0] == 0 { + sig[0] = 1 + } else { + sig[0] = 0 + } + msg = createConnectionMsg(PKIidType("localhost:9612"), sig) + stream.Send(msg) + msg, err = stream.Recv() + assert.Equal(t, []byte("localhost:9611"), msg.GetConn().PkiID) + assert.NoError(t, err, "%v", err) + msg2Send = createGossipMsg() + nonce = uint64(rand.Int()) + msg2Send.Nonce = nonce + stream.Send(msg2Send) + time.Sleep(time.Second) + assert.Equal(t, 0, len(rcvChan)) +} + +func TestBasic(t *testing.T) { + comm1, _ := newCommInstance(2000, naiveSec) + comm2, _ := newCommInstance(3000, naiveSec) + defer comm1.Stop() + defer comm2.Stop() + time.Sleep(time.Duration(3) * time.Second) + msgs := make(chan *proto.GossipMessage, 2) + go func() { + m := <-comm2.Accept(acceptAll) + msgs <- m.GetGossipMessage() + }() + go func() { + m := <-comm1.Accept(acceptAll) + msgs <- m.GetGossipMessage() + }() + comm1.Send(createGossipMsg(), &RemotePeer{PKIID: []byte("localhost:3000"), Endpoint: "localhost:3000"}) + time.Sleep(time.Second) + comm2.Send(createGossipMsg(), &RemotePeer{PKIID: []byte("localhost:2000"), Endpoint: "localhost:2000"}) + time.Sleep(time.Second) + assert.Equal(t, 2, len(msgs)) +} + +func TestBlackListPKIid(t *testing.T) { + comm1, _ := newCommInstance(1611, naiveSec) + comm2, _ := newCommInstance(1612, naiveSec) + comm3, _ := newCommInstance(1613, naiveSec) + comm4, _ := newCommInstance(1614, naiveSec) + defer comm1.Stop() + defer comm2.Stop() + defer comm3.Stop() + defer comm4.Stop() + + reader := func(out chan uint64, in <-chan ReceivedMessage) { + for { + msg := <-in + if msg == nil { + return + } + out <- msg.GetGossipMessage().Nonce + } + } + + sender := func(comm Comm, port int, n int) { + endpoint := fmt.Sprintf("localhost:%d", port) + for i := 0; i < n; i++ { + comm.Send(createGossipMsg(), &RemotePeer{Endpoint: endpoint, PKIID: []byte(endpoint)}) + time.Sleep(time.Duration(1) * time.Second) + } + } + + out1 := make(chan uint64, 5) + out2 := make(chan uint64, 5) + out3 := make(chan uint64, 10) + out4 := make(chan uint64, 10) + + go reader(out1, comm1.Accept(acceptAll)) + go reader(out2, comm2.Accept(acceptAll)) + go reader(out3, comm3.Accept(acceptAll)) + go reader(out4, comm4.Accept(acceptAll)) + + // have comm1 BL comm3 + comm1.BlackListPKIid([]byte("localhost:1613")) + + // make comm3 send to 1 and 2 + go sender(comm3, 1611, 5) + go sender(comm3, 1612, 5) + + // make comm1 and comm2 send to comm3 + go sender(comm1, 1613, 5) + go sender(comm2, 1613, 5) + + // make comm1 and comm2 send to comm4 which is not blacklisted + go sender(comm1, 1614, 5) + go sender(comm2, 1614, 5) + + time.Sleep(time.Duration(1) * time.Second) + + // blacklist comm3 mid-sending + comm2.BlackListPKIid([]byte("localhost:1613")) + time.Sleep(time.Duration(5) * time.Second) + + assert.Equal(t, 0, len(out1), "Comm instance 1 received messages(%d) from comm3 although comm3 is black listed", len(out1)) + assert.True(t, len(out2) < 2, "Comm instance 2 received too many messages(%d) from comm3 although comm3 is black listed", len(out2)) + assert.True(t, len(out3) < 3, "Comm instance 3 received too many messages(%d) although black listed", len(out3)) + assert.Equal(t, 10, len(out4), "Comm instance 4 didn't receive all messages sent to it") +} + +func TestParallelSend(t *testing.T) { + comm1, _ := newCommInstance(5611, naiveSec) + comm2, _ := newCommInstance(5612, naiveSec) + defer comm1.Stop() + defer comm2.Stop() + + messages2Send := 100 + + wg := sync.WaitGroup{} + go func() { + for i := 0; i < messages2Send; i++ { + wg.Add(1) + emptyMsg := createGossipMsg() + go func() { + defer wg.Done() + comm1.Send(emptyMsg, &RemotePeer{Endpoint: "localhost:5612", PKIID: []byte("localhost:5612")}) + }() + } + wg.Wait() + }() + + c := 0 + waiting := true + ticker := time.NewTicker(time.Duration(1) * time.Second) + ch := comm2.Accept(acceptAll) + for waiting { + select { + case <-ch: + c++ + continue + case <-ticker.C: + waiting = false + break + } + } + assert.Equal(t, messages2Send, c) +} + +func TestResponses(t *testing.T) { + comm1, _ := newCommInstance(8611, naiveSec) + comm2, _ := newCommInstance(8612, naiveSec) + + defer comm1.Stop() + defer comm2.Stop() + + nonceIncrememter := func(msg ReceivedMessage) ReceivedMessage { + msg.GetGossipMessage().Nonce++ + return msg + } + + msg := createGossipMsg() + go func() { + inChan := comm1.Accept(acceptAll) + for m := range inChan { + m = nonceIncrememter(m) + m.Respond(m.GetGossipMessage()) + } + }() + expectedNOnce := uint64(msg.Nonce + 1) + responsesFromComm1 := comm2.Accept(acceptAll) + + ticker := time.NewTicker(time.Duration(6000) * time.Millisecond) + comm2.Send(msg, &RemotePeer{PKIID: []byte("localhost:8611"), Endpoint: "localhost:8611"}) + time.Sleep(time.Duration(100) * time.Millisecond) + + select { + case <-ticker.C: + assert.Fail(t, "Haven't got response from comm1 within a timely manner") + break + case resp := <-responsesFromComm1: + ticker.Stop() + assert.Equal(t, expectedNOnce, resp.GetGossipMessage().Nonce) + break + } +} + +func TestAccept(t *testing.T) { + comm1, _ := newCommInstance(7611, naiveSec) + comm2, _ := newCommInstance(7612, naiveSec) + + evenNONCESelector := func(m interface{}) bool { + return m.(ReceivedMessage).GetGossipMessage().Nonce%2 == 0 + } + + oddNONCESelector := func(m interface{}) bool { + return m.(ReceivedMessage).GetGossipMessage().Nonce%2 != 0 + } + + evenNONCES := comm1.Accept(evenNONCESelector) + oddNONCES := comm1.Accept(oddNONCESelector) + + var evenResults []uint64 + var oddResults []uint64 + + sem := make(chan struct{}, 0) + + readIntoSlice := func(a *[]uint64, ch <-chan ReceivedMessage) { + for m := range ch { + *a = append(*a, m.GetGossipMessage().Nonce) + } + sem <- struct{}{} + } + + go readIntoSlice(&evenResults, evenNONCES) + go readIntoSlice(&oddResults, oddNONCES) + + for i := 0; i < 100; i++ { + comm2.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:7611", PKIID: []byte("localhost:7611")}) + } + + time.Sleep(time.Duration(1) * time.Second) + + comm1.Stop() + comm2.Stop() + + <-sem + <-sem + + assert.NotEmpty(t, evenResults) + assert.NotEmpty(t, oddResults) + + remainderPredicate := func(a []uint64, rem uint64) { + for _, n := range a { + assert.Equal(t, n%2, rem) + } + } + + remainderPredicate(evenResults, 0) + remainderPredicate(oddResults, 1) +} + +func TestReConnections(t *testing.T) { + comm1, _ := newCommInstance(3611, naiveSec) + comm2, _ := newCommInstance(3612, naiveSec) + + reader := func(out chan uint64, in <-chan ReceivedMessage) { + for { + msg := <-in + if msg == nil { + return + } + out <- msg.GetGossipMessage().Nonce + } + } + + out1 := make(chan uint64, 10) + out2 := make(chan uint64, 10) + + go reader(out1, comm1.Accept(acceptAll)) + go reader(out2, comm2.Accept(acceptAll)) + + // comm1 connects to comm2 + comm1.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:3612", PKIID: []byte("localhost:3612")}) + time.Sleep(100 * time.Millisecond) + // comm2 sends to comm1 + comm2.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:3611", PKIID: []byte("localhost:3611")}) + time.Sleep(100 * time.Millisecond) + + assert.Equal(t, 1, len(out2)) + assert.Equal(t, 1, len(out1)) + + comm1.Stop() + comm1, _ = newCommInstance(3611, naiveSec) + go reader(out1, comm1.Accept(acceptAll)) + time.Sleep(300 * time.Millisecond) + comm2.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:3611", PKIID: []byte("localhost:3611")}) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, 2, len(out1)) +} + +func TestProbe(t *testing.T) { + comm1, _ := newCommInstance(6611, naiveSec) + defer comm1.Stop() + comm2, _ := newCommInstance(6612, naiveSec) + time.Sleep(time.Duration(1) * time.Second) + assert.NoError(t, comm1.Probe("localhost:6612", []byte("localhost:6612"))) + assert.Error(t, comm1.Probe("localhost:9012", []byte("localhost:9012"))) + comm2.Stop() + time.Sleep(time.Second) + assert.Error(t, comm1.Probe("localhost:6612", []byte("localhost:6612"))) + comm2, _ = newCommInstance(6612, naiveSec) + defer comm2.Stop() + time.Sleep(time.Duration(1) * time.Second) + assert.NoError(t, comm2.Probe("localhost:6611", []byte("localhost:6611"))) + assert.NoError(t, comm1.Probe("localhost:6612", []byte("localhost:6612"))) +} + +func TestPresumedDead(t *testing.T) { + comm1, _ := newCommInstance(7611, naiveSec) + defer comm1.Stop() + comm2, _ := newCommInstance(7612, naiveSec) + go comm1.Send(createGossipMsg(), &RemotePeer{PKIID: []byte("localhost:7612"), Endpoint: "localhost:7612"}) + <-comm2.Accept(acceptAll) + comm2.Stop() + for i := 0; i < 5; i++ { + comm1.Send(createGossipMsg(), &RemotePeer{Endpoint: "localhost:7612", PKIID: []byte("localhost:7612")}) + time.Sleep(time.Second) + } + ticker := time.NewTicker(time.Second * time.Duration(3)) + select { + case <-ticker.C: + assert.Fail(t, "Didn't get a presumed dead message within a timely manner") + break + case <-comm1.PresumedDead(): + ticker.Stop() + break + } +} + +func createGossipMsg() *proto.GossipMessage { + return &proto.GossipMessage{ + Nonce: uint64(rand.Int()), + Content: &proto.GossipMessage_DataMsg{ + DataMsg: &proto.DataMessage{}, + }, + } +} diff --git a/gossip/comm/conn.go b/gossip/comm/conn.go new file mode 100644 index 00000000000..b37b2c7af7a --- /dev/null +++ b/gossip/comm/conn.go @@ -0,0 +1,330 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "fmt" + "sync" + "sync/atomic" + + "github.com/hyperledger/fabric/gossip/proto" + "github.com/hyperledger/fabric/gossip/util" + "google.golang.org/grpc" +) + +type handler func(*proto.GossipMessage) + +type connFactory interface { + createConnection(endpoint string, pkiID PKIidType) (*connection, error) +} + +type connectionStore struct { + logger *util.Logger // logger + selfPKIid PKIidType // pkiID of this peer + isClosing bool // whether this connection store is shutting down + connFactory connFactory // creates a connection to remote peer + sync.RWMutex // synchronize access to shared variables + pki2Conn map[string]*connection // mapping between pkiID to connections + destinationLocks map[string]*sync.RWMutex //mapping between pkiIDs and locks, + // used to prevent concurrent connection establishment to the same remote endpoint +} + +func newConnStore(connFactory connFactory, pkiID PKIidType, logger *util.Logger) *connectionStore { + return &connectionStore{ + connFactory: connFactory, + isClosing: false, + pki2Conn: make(map[string]*connection), + selfPKIid: pkiID, + destinationLocks: make(map[string]*sync.RWMutex), + logger: logger, + } +} + +func (cs *connectionStore) getConnection(peer *RemotePeer) (*connection, error) { + cs.RLock() + isClosing := cs.isClosing + cs.RUnlock() + + if isClosing { + return nil, fmt.Errorf("Shutting down") + } + + pkiID := peer.PKIID + endpoint := peer.Endpoint + + cs.Lock() + destinationLock, hasConnected := cs.destinationLocks[string(pkiID)] + if !hasConnected { + destinationLock = &sync.RWMutex{} + cs.destinationLocks[string(pkiID)] = destinationLock + } + cs.Unlock() + + destinationLock.Lock() + + cs.RLock() + conn, exists := cs.pki2Conn[string(pkiID)] + if exists { + cs.RUnlock() + destinationLock.Unlock() + return conn, nil + } + cs.RUnlock() + + createdConnection, err := cs.connFactory.createConnection(endpoint, pkiID) + + destinationLock.Unlock() + + cs.Lock() + delete(cs.destinationLocks, string(pkiID)) + defer cs.Unlock() + + // check again, maybe someone connected to us during the connection creation? + conn, exists = cs.pki2Conn[string(pkiID)] + + if exists { + if createdConnection != nil { + createdConnection.close() + } + return conn, nil + } + + // no one connected to us AND we failed connecting! + if err != nil { + return nil, err + } + + // at this point in the code, we created a connection to a remote peer + conn = createdConnection + cs.pki2Conn[string(createdConnection.pkiID)] = conn + + go conn.serviceInput() + + return conn, nil +} + +func (cs *connectionStore) connNum() int { + cs.RLock() + defer cs.RUnlock() + return len(cs.pki2Conn) +} + +func (cs *connectionStore) closeConn(peer *RemotePeer) { + cs.Lock() + defer cs.Unlock() + + if conn, exists := cs.pki2Conn[string(peer.PKIID)]; exists { + conn.close() + delete(cs.pki2Conn, string(conn.pkiID)) + } +} + +func (cs *connectionStore) shutdown() { + cs.Lock() + cs.isClosing = true + pkiIds2conn := cs.pki2Conn + cs.Unlock() + + wg := sync.WaitGroup{} + for _, conn := range pkiIds2conn { + wg.Add(1) + go func(conn *connection) { + cs.closeByPKIid(conn.pkiID) + wg.Done() + }(conn) + } + wg.Wait() +} + +func (cs *connectionStore) onConnected(serverStream proto.Gossip_GossipStreamServer, pkiID PKIidType) *connection { + cs.Lock() + defer cs.Unlock() + + if c, exists := cs.pki2Conn[string(pkiID)]; exists { + c.close() + } + + return cs.registerConn(pkiID, serverStream) +} + +func (cs *connectionStore) registerConn(pkiID PKIidType, serverStream proto.Gossip_GossipStreamServer) *connection { + conn := newConnection(nil, nil, nil, serverStream) + conn.pkiID = pkiID + conn.logger = cs.logger + cs.pki2Conn[string(pkiID)] = conn + return conn +} + +func (cs *connectionStore) closeByPKIid(pkiID PKIidType) { + cs.Lock() + defer cs.Unlock() + if conn, exists := cs.pki2Conn[string(pkiID)]; exists { + conn.close() + delete(cs.pki2Conn, string(pkiID)) + } +} + +func newConnection(cl proto.GossipClient, c *grpc.ClientConn, cs proto.Gossip_GossipStreamClient, ss proto.Gossip_GossipStreamServer) *connection { + connection := &connection{ + cl: cl, + conn: c, + clientStream: cs, + serverStream: ss, + stopFlag: int32(0), + stopChan: make(chan struct{}, 1), + } + + return connection +} + +type connection struct { + logger *util.Logger // logger + pkiID PKIidType // pkiID of the remote endpoint + handler handler // function to invoke upon a message reception + conn *grpc.ClientConn // gRPC connection to remote endpoint + cl proto.GossipClient // gRPC stub of remote endpoint + clientStream proto.Gossip_GossipStreamClient // client-side stream to remote endpoint + serverStream proto.Gossip_GossipStreamServer // server-side stream to remote endpoint + stopFlag int32 // indicates whether this connection is in process of stopping + stopChan chan struct{} // a method to stop the server-side gRPC call from a different go-routine + sync.RWMutex // synchronizes access to shared variables +} + +func (conn *connection) close() { + if conn.toDie() { + return + } + + amIFirst := atomic.CompareAndSwapInt32(&conn.stopFlag, int32(0), int32(1)) + if !amIFirst { + return + } + + conn.stopChan <- struct{}{} + + conn.Lock() + + if conn.clientStream != nil { + conn.clientStream.CloseSend() + } + if conn.conn != nil { + conn.conn.Close() + } + + conn.Unlock() + +} + +func (conn *connection) toDie() bool { + return atomic.LoadInt32(&(conn.stopFlag)) == int32(1) +} + +func (conn *connection) send(msg *proto.GossipMessage) error { + conn.Lock() + defer conn.Unlock() + + if conn.toDie() { + return fmt.Errorf("Connection aborted") + } + + if conn.clientStream != nil { + return conn.clientStream.Send(msg) + } + + if conn.serverStream != nil { + return conn.serverStream.Send(msg) + } + + return fmt.Errorf("Both streams are nil") +} + +func (conn *connection) serviceInput() error { + errChan := make(chan error, 1) + msgChan := make(chan *proto.GossipMessage, defRecvBuffSize) + defer close(msgChan) + + // Call stream.Recv() asynchronously in readFromStream(), + // and wait for either the Recv() call to end, + // or a signal to close the connection, which exits + // the method and makes the Recv() call to fail in the + // readFromStream() method + go conn.readFromStream(errChan, msgChan) + + for !conn.toDie() { + select { + case stop := <-conn.stopChan: + conn.logger.Warning("Closing reading from stream") + conn.stopChan <- stop + return nil + case err := <-errChan: + return err + case msg := <-msgChan: + conn.handler(msg) + } + } + return nil +} + +func (conn *connection) readFromStream(errChan chan error, msgChan chan *proto.GossipMessage) { + defer func() { + recover() + }() // msgChan might be closed + for !conn.toDie() { + stream := conn.getStream() + if stream == nil { + conn.logger.Error(conn.pkiID, "Stream is nil, aborting!") + errChan <- fmt.Errorf("Stream is nil") + return + } + msg, err := stream.Recv() + if conn.toDie() { + conn.logger.Warning(conn.pkiID, "canceling read because closing") + return + } + if err != nil { + errChan <- err + conn.logger.Warning(conn.pkiID, "Got error, aborting:", err) + return + } + msgChan <- msg + } +} + +func (conn *connection) getStream() stream { + conn.Lock() + defer conn.Unlock() + + if conn.clientStream != nil && conn.serverStream != nil { + e := "Both client and server stream are not nil, something went wrong" + conn.logger.Error(e) + } + + if conn.clientStream != nil { + return conn.clientStream + } + + if conn.serverStream != nil { + return conn.serverStream + } + + return nil +} + +type msgSending struct { + msg *proto.GossipMessage + onErr func(error) +} diff --git a/gossip/comm/crypto.go b/gossip/comm/crypto.go new file mode 100644 index 00000000000..53eb40a0b1d --- /dev/null +++ b/gossip/comm/crypto.go @@ -0,0 +1,111 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "math/big" + "os" + + "crypto/tls" + "net" + "time" + + "golang.org/x/net/context" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/peer" +) + +func writeFile(filename string, keyType string, data []byte) error { + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + return pem.Encode(f, &pem.Block{Type: keyType, Bytes: data}) +} + +func generateCertificates(privKeyFile string, certKeyFile string) error { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return err + } + + sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + template := x509.Certificate{ + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + SerialNumber: sn, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + rawBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) + if err != nil { + return err + } + err = writeFile(certKeyFile, "CERTIFICATE", rawBytes) + if err != nil { + return err + } + privBytes, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return err + } + err = writeFile(privKeyFile, "EC PRIVATE KEY", privBytes) + return err +} + +// ExtractTLSUnique extracts the TLS-Unique from the stream +func ExtractTLSUnique(ctx context.Context) []byte { + pr, extracted := peer.FromContext(ctx) + if !extracted { + return nil + } + + authInfo := pr.AuthInfo + if authInfo == nil { + return nil + } + + tlsInfo, isTLSConn := authInfo.(credentials.TLSInfo) + if !isTLSConn { + return nil + } + return tlsInfo.State.TLSUnique +} + +type authCreds struct { + tlsCreds credentials.TransportCredentials +} + +func (c authCreds) Info() credentials.ProtocolInfo { + return c.tlsCreds.Info() +} + +func (c *authCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ credentials.AuthInfo, err error) { + conn, auth, err := c.tlsCreds.ClientHandshake(addr, rawConn, timeout) + if auth == nil && conn != nil { + auth = credentials.TLSInfo{State: conn.(*tls.Conn).ConnectionState()} + } + return conn, auth, err +} + +func (c *authCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return c.tlsCreds.ServerHandshake(rawConn) +} diff --git a/gossip/comm/crypto_test.go b/gossip/comm/crypto_test.go new file mode 100644 index 00000000000..d5b96603d7a --- /dev/null +++ b/gossip/comm/crypto_test.go @@ -0,0 +1,132 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "bytes" + "crypto/tls" + "fmt" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/hyperledger/fabric/gossip/proto" + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +type gossipTestServer struct { + lock sync.Mutex + msgChan chan uint64 + tlsUnique []byte +} + +func (s *gossipTestServer) GossipStream(stream proto.Gossip_GossipStreamServer) error { + s.lock.Lock() + s.tlsUnique = ExtractTLSUnique(stream.Context()) + s.lock.Unlock() + m, err := stream.Recv() + if err != nil { + fmt.Println(err) + } else { + s.msgChan <- m.Nonce + } + + return nil +} + +func (s *gossipTestServer) getTLSUnique() []byte { + s.lock.Lock() + defer s.lock.Unlock() + return s.tlsUnique +} + +func (s *gossipTestServer) Ping(context.Context, *proto.Empty) (*proto.Empty, error) { + return &proto.Empty{}, nil +} + +func TestCertificateGeneration(t *testing.T) { + err := generateCertificates("key.pem", "cert.pem") + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + defer os.Remove("cert.pem") + defer os.Remove("key.pem") + var ll net.Listener + creds, err := credentials.NewServerTLSFromFile("cert.pem", "key.pem") + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + s := grpc.NewServer(grpc.Creds(creds)) + ll, err = net.Listen("tcp", fmt.Sprintf("%s:%d", "", 5611)) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + srv := &gossipTestServer{msgChan: make(chan uint64)} + proto.RegisterGossipServer(s, srv) + go s.Serve(ll) + defer func() { + s.Stop() + ll.Close() + }() + time.Sleep(time.Second * time.Duration(2)) + ta := credentials.NewTLS(&tls.Config{ + InsecureSkipVerify: true, + }) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + conn, err := grpc.Dial("localhost:5611", grpc.WithTransportCredentials(&authCreds{tlsCreds: ta}), grpc.WithBlock(), grpc.WithTimeout(time.Second)) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + cl := proto.NewGossipClient(conn) + stream, err := cl.GossipStream(context.Background()) + assert.NoError(t, err, "%v", err) + if err != nil { + return + } + + time.Sleep(time.Duration(1) * time.Second) + + clientTLSUnique := ExtractTLSUnique(stream.Context()) + serverTLSUnique := srv.getTLSUnique() + + assert.NotNil(t, clientTLSUnique) + assert.NotNil(t, serverTLSUnique) + + assert.True(t, bytes.Equal(clientTLSUnique, serverTLSUnique), "Client and server TLSUnique are not equal") + + msg := createGossipMsg() + stream.Send(msg) + select { + case nonce := <-srv.msgChan: + assert.Equal(t, msg.Nonce, nonce) + break + case <-time.NewTicker(time.Second).C: + assert.Fail(t, "Timed out reading from stream") + } +} diff --git a/gossip/comm/demux.go b/gossip/comm/demux.go new file mode 100644 index 00000000000..b6976d7fb20 --- /dev/null +++ b/gossip/comm/demux.go @@ -0,0 +1,97 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "sync" + "sync/atomic" + + "github.com/hyperledger/fabric/gossip/util" +) + +// ChannelDeMultiplexer is a struct that can receive channel registrations (AddChannel) +// and publications (DeMultiplex) and it broadcasts the publications to registrations +// according to their predicate +type ChannelDeMultiplexer struct { + channels []*channel + lock *sync.RWMutex + closed int32 +} + +// NewChannelDemultiplexer creates a new ChannelDeMultiplexer +func NewChannelDemultiplexer() *ChannelDeMultiplexer { + return &ChannelDeMultiplexer{ + channels: make([]*channel, 0), + lock: &sync.RWMutex{}, + closed: int32(0), + } +} + +type channel struct { + pred util.MessageAcceptor + ch chan interface{} +} + +func (m *ChannelDeMultiplexer) isClosed() bool { + return atomic.LoadInt32(&m.closed) == int32(1) +} + +// Close closes this channel, which makes all channels registered before +// to close as well. +func (m *ChannelDeMultiplexer) Close() { + defer func() { + // recover closing an already closed channel + recover() + }() + atomic.StoreInt32(&m.closed, int32(1)) + m.lock.Lock() + defer m.lock.Unlock() + for _, ch := range m.channels { + close(ch.ch) + } +} + +// AddChannel registers a channel with a certain predicate +func (m *ChannelDeMultiplexer) AddChannel(predicate util.MessageAcceptor) chan interface{} { + m.lock.Lock() + defer m.lock.Unlock() + ch := &channel{ch: make(chan interface{}, 10), pred: predicate} + m.channels = append(m.channels, ch) + return ch.ch +} + +// DeMultiplex broadcasts the message to all channels that were returned +// by AddChannel calls and that hold the respected predicates. +func (m *ChannelDeMultiplexer) DeMultiplex(msg interface{}) { + defer func() { + recover() + }() // recover from sending on a closed channel + + if m.isClosed() { + return + } + + m.lock.RLock() + channels := m.channels + m.lock.RUnlock() + + for _, ch := range channels { + if ch.pred(msg) { + ch.ch <- msg + } + } +} diff --git a/gossip/comm/msg.go b/gossip/comm/msg.go new file mode 100644 index 00000000000..61fbe1a2f4c --- /dev/null +++ b/gossip/comm/msg.go @@ -0,0 +1,40 @@ +/* +Copyright IBM Corp. 2016 All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package comm + +import ( + "sync" + + "github.com/hyperledger/fabric/gossip/proto" +) + +// ReceivedMessageImpl is an implementation of ReceivedMessage +type ReceivedMessageImpl struct { + *proto.GossipMessage + lock sync.Locker + conn *connection +} + +// Respond sends a msg to the source that sent the ReceivedMessageImpl +func (m *ReceivedMessageImpl) Respond(msg *proto.GossipMessage) { + m.conn.send(msg) +} + +// GetGossipMessage returns the inner GossipMessage +func (m *ReceivedMessageImpl) GetGossipMessage() *proto.GossipMessage { + return m.GossipMessage +} diff --git a/gossip/util/logging.go b/gossip/util/logging.go index 99d23ac1a78..156e8af8fa6 100644 --- a/gossip/util/logging.go +++ b/gossip/util/logging.go @@ -18,16 +18,22 @@ package util import ( "fmt" - "github.com/op/go-logging" + "io/ioutil" + "log" "os" "sync" + + "github.com/op/go-logging" + + "google.golang.org/grpc/grpclog" ) const ( LOGGING_MESSAGE_BUFF_MODULE = "mbuff" LOGGING_EMITTER_MODULE = "emitter" - LOGGING_GOSMEMBER_MODULE = "gossip" + LOGGING_GOSSIP_MODULE = "gossip" LOGGING_DISCOVERY_MODULE = "discovery" + LOGGING_COMM_MODULE = "comm" ) var loggersByModules = make(map[string]*Logger) @@ -35,11 +41,12 @@ var defaultLevel = logging.WARNING var lock = sync.Mutex{} var format = logging.MustStringFormatter( - `%{color}%{level} %{longfunc}():%{color:reset}(%{module})%{message}`, + `%{color} %{level} %{longfunc}():%{color:reset}(%{module})%{message}`, ) func init() { logging.SetFormatter(format) + grpclog.SetLogger(log.New(ioutil.Discard, "", 0)) } type Logger struct {