diff --git a/waku/v2/protocol/filter/client.go b/waku/v2/protocol/filter/client.go index e1c1bb6a0..6c43a7e48 100644 --- a/waku/v2/protocol/filter/client.go +++ b/waku/v2/protocol/filter/client.go @@ -106,25 +106,34 @@ func (wf *WakuFilterLightNode) Stop() { }) } -func (wf *WakuFilterLightNode) onRequest(ctx context.Context) func(s network.Stream) { - return func(s network.Stream) { - defer s.Close() - logger := wf.log.With(logging.HostID("peer", s.Conn().RemotePeer())) - if !wf.subscriptions.IsSubscribedTo(s.Conn().RemotePeer()) { - logger.Warn("received message push from unknown peer", logging.HostID("peerID", s.Conn().RemotePeer())) +func (wf *WakuFilterLightNode) onRequest(ctx context.Context) func(network.Stream) { + return func(stream network.Stream) { + peerID := stream.Conn().RemotePeer() + logger := wf.log.With(logging.HostID("peer", peerID)) + if !wf.subscriptions.IsSubscribedTo(peerID) { + logger.Warn("received message push from unknown peer", logging.HostID("peerID", peerID)) wf.metrics.RecordError(unknownPeerMessagePush) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } - reader := pbio.NewDelimitedReader(s, math.MaxInt32) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) messagePush := &pb.MessagePushV2{} err := reader.ReadMsg(messagePush) if err != nil { logger.Error("reading message push", zap.Error(err)) wf.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } + + stream.Close() + pubSubTopic := "" //For now returning failure, this will get addressed with autosharding changes for filter. if messagePush.PubsubTopic == nil { @@ -132,14 +141,17 @@ func (wf *WakuFilterLightNode) onRequest(ctx context.Context) func(s network.Str if err != nil { logger.Error("could not derive pubSubTopic from contentTopic", zap.Error(err)) wf.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } } else { pubSubTopic = *messagePush.PubsubTopic } - if !wf.subscriptions.Has(s.Conn().RemotePeer(), protocol.NewContentFilter(pubSubTopic, messagePush.WakuMessage.ContentTopic)) { + if !wf.subscriptions.Has(peerID, protocol.NewContentFilter(pubSubTopic, messagePush.WakuMessage.ContentTopic)) { logger.Warn("received messagepush with invalid subscription parameters", - logging.HostID("peerID", s.Conn().RemotePeer()), zap.String("topic", pubSubTopic), + zap.String("topic", pubSubTopic), zap.String("contentTopic", messagePush.WakuMessage.ContentTopic)) wf.metrics.RecordError(invalidSubscriptionMessage) return @@ -147,7 +159,7 @@ func (wf *WakuFilterLightNode) onRequest(ctx context.Context) func(s network.Str wf.metrics.RecordMessage() - wf.notify(s.Conn().RemotePeer(), pubSubTopic, messagePush.WakuMessage) + wf.notify(peerID, pubSubTopic, messagePush.WakuMessage) logger.Info("received message push") } @@ -166,15 +178,14 @@ func (wf *WakuFilterLightNode) notify(remotePeerID peer.ID, pubsubTopic string, func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscribeParameters, reqType pb.FilterSubscribeRequest_FilterSubscribeType, contentFilter protocol.ContentFilter) error { - conn, err := wf.h.NewStream(ctx, params.selectedPeer, FilterSubscribeID_v20beta1) + stream, err := wf.h.NewStream(ctx, params.selectedPeer, FilterSubscribeID_v20beta1) if err != nil { wf.metrics.RecordError(dialFailure) return err } - defer conn.Close() - writer := pbio.NewDelimitedWriter(conn) - reader := pbio.NewDelimitedReader(conn, math.MaxInt32) + writer := pbio.NewDelimitedWriter(stream) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) request := &pb.FilterSubscribeRequest{ RequestId: hex.EncodeToString(params.requestID), @@ -188,6 +199,9 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr if err != nil { wf.metrics.RecordError(writeRequestFailure) wf.log.Error("sending FilterSubscribeRequest", zap.Error(err)) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return err } @@ -196,8 +210,14 @@ func (wf *WakuFilterLightNode) request(ctx context.Context, params *FilterSubscr if err != nil { wf.log.Error("receiving FilterSubscribeResponse", zap.Error(err)) wf.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return err } + + stream.Close() + if filterSubscribeResponse.RequestId != request.RequestId { wf.log.Error("requestID mismatch", zap.String("expected", request.RequestId), zap.String("received", filterSubscribeResponse.RequestId)) wf.metrics.RecordError(requestIDMismatch) diff --git a/waku/v2/protocol/filter/server.go b/waku/v2/protocol/filter/server.go index 221665706..324bb3d11 100644 --- a/waku/v2/protocol/filter/server.go +++ b/waku/v2/protocol/filter/server.go @@ -83,18 +83,20 @@ func (wf *WakuFilterFullNode) start(sub *relay.Subscription) error { return nil } -func (wf *WakuFilterFullNode) onRequest(ctx context.Context) func(s network.Stream) { - return func(s network.Stream) { - defer s.Close() - logger := wf.log.With(logging.HostID("peer", s.Conn().RemotePeer())) +func (wf *WakuFilterFullNode) onRequest(ctx context.Context) func(network.Stream) { + return func(stream network.Stream) { + logger := wf.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) - reader := pbio.NewDelimitedReader(s, math.MaxInt32) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) subscribeRequest := &pb.FilterSubscribeRequest{} err := reader.ReadMsg(subscribeRequest) if err != nil { wf.metrics.RecordError(decodeRPCFailure) logger.Error("reading request", zap.Error(err)) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } @@ -104,22 +106,24 @@ func (wf *WakuFilterFullNode) onRequest(ctx context.Context) func(s network.Stre switch subscribeRequest.FilterSubscribeType { case pb.FilterSubscribeRequest_SUBSCRIBE: - wf.subscribe(ctx, s, subscribeRequest) + wf.subscribe(ctx, stream, subscribeRequest) case pb.FilterSubscribeRequest_SUBSCRIBER_PING: - wf.ping(ctx, s, subscribeRequest) + wf.ping(ctx, stream, subscribeRequest) case pb.FilterSubscribeRequest_UNSUBSCRIBE: - wf.unsubscribe(ctx, s, subscribeRequest) + wf.unsubscribe(ctx, stream, subscribeRequest) case pb.FilterSubscribeRequest_UNSUBSCRIBE_ALL: - wf.unsubscribeAll(ctx, s, subscribeRequest) + wf.unsubscribeAll(ctx, stream, subscribeRequest) } + stream.Close() + wf.metrics.RecordRequest(subscribeRequest.FilterSubscribeType.String(), time.Since(start)) logger.Info("received request", zap.String("requestType", subscribeRequest.FilterSubscribeType.String())) } } -func (wf *WakuFilterFullNode) reply(ctx context.Context, s network.Stream, request *pb.FilterSubscribeRequest, statusCode int, description ...string) { +func (wf *WakuFilterFullNode) reply(ctx context.Context, stream network.Stream, request *pb.FilterSubscribeRequest, statusCode int, description ...string) { response := &pb.FilterSubscribeResponse{ RequestId: request.RequestId, StatusCode: uint32(statusCode), @@ -131,45 +135,48 @@ func (wf *WakuFilterFullNode) reply(ctx context.Context, s network.Stream, reque response.StatusDesc = http.StatusText(statusCode) } - writer := pbio.NewDelimitedWriter(s) + writer := pbio.NewDelimitedWriter(stream) err := writer.WriteMsg(response) if err != nil { wf.metrics.RecordError(writeResponseFailure) wf.log.Error("sending response", zap.Error(err)) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } } } -func (wf *WakuFilterFullNode) ping(ctx context.Context, s network.Stream, request *pb.FilterSubscribeRequest) { - exists := wf.subscriptions.Has(s.Conn().RemotePeer()) +func (wf *WakuFilterFullNode) ping(ctx context.Context, stream network.Stream, request *pb.FilterSubscribeRequest) { + exists := wf.subscriptions.Has(stream.Conn().RemotePeer()) if exists { - wf.reply(ctx, s, request, http.StatusOK) + wf.reply(ctx, stream, request, http.StatusOK) } else { - wf.reply(ctx, s, request, http.StatusNotFound, peerHasNoSubscription) + wf.reply(ctx, stream, request, http.StatusNotFound, peerHasNoSubscription) } } -func (wf *WakuFilterFullNode) subscribe(ctx context.Context, s network.Stream, request *pb.FilterSubscribeRequest) { +func (wf *WakuFilterFullNode) subscribe(ctx context.Context, stream network.Stream, request *pb.FilterSubscribeRequest) { if request.PubsubTopic == nil { - wf.reply(ctx, s, request, http.StatusBadRequest, "pubsubtopic can't be empty") + wf.reply(ctx, stream, request, http.StatusBadRequest, "pubsubtopic can't be empty") return } if len(request.ContentTopics) == 0 { - wf.reply(ctx, s, request, http.StatusBadRequest, "at least one contenttopic should be specified") + wf.reply(ctx, stream, request, http.StatusBadRequest, "at least one contenttopic should be specified") return } if len(request.ContentTopics) > MaxContentTopicsPerRequest { - wf.reply(ctx, s, request, http.StatusBadRequest, fmt.Sprintf("exceeds maximum content topics: %d", MaxContentTopicsPerRequest)) + wf.reply(ctx, stream, request, http.StatusBadRequest, fmt.Sprintf("exceeds maximum content topics: %d", MaxContentTopicsPerRequest)) } if wf.subscriptions.Count() >= wf.maxSubscriptions { - wf.reply(ctx, s, request, http.StatusServiceUnavailable, "node has reached maximum number of subscriptions") + wf.reply(ctx, stream, request, http.StatusServiceUnavailable, "node has reached maximum number of subscriptions") return } - peerID := s.Conn().RemotePeer() + peerID := stream.Conn().RemotePeer() if totalSubs, exists := wf.subscriptions.Get(peerID); exists { ctTotal := 0 @@ -178,7 +185,7 @@ func (wf *WakuFilterFullNode) subscribe(ctx context.Context, s network.Stream, r } if ctTotal+len(request.ContentTopics) > MaxCriteriaPerSubscription { - wf.reply(ctx, s, request, http.StatusServiceUnavailable, "peer has reached maximum number of filter criteria") + wf.reply(ctx, stream, request, http.StatusServiceUnavailable, "peer has reached maximum number of filter criteria") return } } @@ -186,40 +193,40 @@ func (wf *WakuFilterFullNode) subscribe(ctx context.Context, s network.Stream, r wf.subscriptions.Set(peerID, *request.PubsubTopic, request.ContentTopics) wf.metrics.RecordSubscriptions(wf.subscriptions.Count()) - wf.reply(ctx, s, request, http.StatusOK) + wf.reply(ctx, stream, request, http.StatusOK) } -func (wf *WakuFilterFullNode) unsubscribe(ctx context.Context, s network.Stream, request *pb.FilterSubscribeRequest) { +func (wf *WakuFilterFullNode) unsubscribe(ctx context.Context, stream network.Stream, request *pb.FilterSubscribeRequest) { if request.PubsubTopic == nil { - wf.reply(ctx, s, request, http.StatusBadRequest, "pubsubtopic can't be empty") + wf.reply(ctx, stream, request, http.StatusBadRequest, "pubsubtopic can't be empty") return } if len(request.ContentTopics) == 0 { - wf.reply(ctx, s, request, http.StatusBadRequest, "at least one contenttopic should be specified") + wf.reply(ctx, stream, request, http.StatusBadRequest, "at least one contenttopic should be specified") return } if len(request.ContentTopics) > MaxContentTopicsPerRequest { - wf.reply(ctx, s, request, http.StatusBadRequest, fmt.Sprintf("exceeds maximum content topics: %d", MaxContentTopicsPerRequest)) + wf.reply(ctx, stream, request, http.StatusBadRequest, fmt.Sprintf("exceeds maximum content topics: %d", MaxContentTopicsPerRequest)) } - err := wf.subscriptions.Delete(s.Conn().RemotePeer(), *request.PubsubTopic, request.ContentTopics) + err := wf.subscriptions.Delete(stream.Conn().RemotePeer(), *request.PubsubTopic, request.ContentTopics) if err != nil { - wf.reply(ctx, s, request, http.StatusNotFound, peerHasNoSubscription) + wf.reply(ctx, stream, request, http.StatusNotFound, peerHasNoSubscription) } else { wf.metrics.RecordSubscriptions(wf.subscriptions.Count()) - wf.reply(ctx, s, request, http.StatusOK) + wf.reply(ctx, stream, request, http.StatusOK) } } -func (wf *WakuFilterFullNode) unsubscribeAll(ctx context.Context, s network.Stream, request *pb.FilterSubscribeRequest) { - err := wf.subscriptions.DeleteAll(s.Conn().RemotePeer()) +func (wf *WakuFilterFullNode) unsubscribeAll(ctx context.Context, stream network.Stream, request *pb.FilterSubscribeRequest) { + err := wf.subscriptions.DeleteAll(stream.Conn().RemotePeer()) if err != nil { - wf.reply(ctx, s, request, http.StatusNotFound, peerHasNoSubscription) + wf.reply(ctx, stream, request, http.StatusNotFound, peerHasNoSubscription) } else { wf.metrics.RecordSubscriptions(wf.subscriptions.Count()) - wf.reply(ctx, s, request, http.StatusOK) + wf.reply(ctx, stream, request, http.StatusOK) } } @@ -279,7 +286,7 @@ func (wf *WakuFilterFullNode) pushMessage(ctx context.Context, peerID peer.ID, e ctx, cancel := context.WithTimeout(ctx, MessagePushTimeout) defer cancel() - conn, err := wf.h.NewStream(ctx, peerID, FilterPushID_v20beta1) + stream, err := wf.h.NewStream(ctx, peerID, FilterPushID_v20beta1) if err != nil { wf.subscriptions.FlagAsFailure(peerID) if errors.Is(context.DeadlineExceeded, err) { @@ -291,8 +298,7 @@ func (wf *WakuFilterFullNode) pushMessage(ctx context.Context, peerID peer.ID, e return err } - defer conn.Close() - writer := pbio.NewDelimitedWriter(conn) + writer := pbio.NewDelimitedWriter(stream) err = writer.WriteMsg(messagePush) if err != nil { if errors.Is(context.DeadlineExceeded, err) { @@ -302,12 +308,17 @@ func (wf *WakuFilterFullNode) pushMessage(ctx context.Context, peerID peer.ID, e } logger.Error("pushing messages to peer", zap.Error(err)) wf.subscriptions.FlagAsFailure(peerID) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return nil } + stream.Close() + wf.subscriptions.FlagAsSuccess(peerID) - logger.Info("message pushed succesfully") // TODO: remove or change to debug once dogfooding of filter is complete + logger.Debug("message pushed succesfully") return nil } diff --git a/waku/v2/protocol/legacy_filter/waku_filter.go b/waku/v2/protocol/legacy_filter/waku_filter.go index 7b085d959..1e5d92d5f 100644 --- a/waku/v2/protocol/legacy_filter/waku_filter.go +++ b/waku/v2/protocol/legacy_filter/waku_filter.go @@ -103,19 +103,22 @@ func (wf *WakuFilter) start(sub *relay.Subscription) error { wf.log.Info("filter protocol started") return nil } -func (wf *WakuFilter) onRequest(ctx context.Context) func(s network.Stream) { - return func(s network.Stream) { - defer s.Close() - logger := wf.log.With(logging.HostID("peer", s.Conn().RemotePeer())) +func (wf *WakuFilter) onRequest(ctx context.Context) func(network.Stream) { + return func(stream network.Stream) { + peerID := stream.Conn().RemotePeer() + logger := wf.log.With(logging.HostID("peer", peerID)) filterRPCRequest := &pb.FilterRPC{} - reader := pbio.NewDelimitedReader(s, math.MaxInt32) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) err := reader.ReadMsg(filterRPCRequest) if err != nil { wf.metrics.RecordError(decodeRPCFailure) logger.Error("reading request", zap.Error(err)) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } @@ -134,7 +137,7 @@ func (wf *WakuFilter) onRequest(ctx context.Context) func(s network.Stream) { // We're on a full node. // This is a filter request coming from a light node. if filterRPCRequest.Request.Subscribe { - subscriber := Subscriber{peer: s.Conn().RemotePeer(), requestID: filterRPCRequest.RequestId, filter: filterRPCRequest.Request} + subscriber := Subscriber{peer: stream.Conn().RemotePeer(), requestID: filterRPCRequest.RequestId, filter: filterRPCRequest.Request} if subscriber.filter.Topic == "" { // @TODO: review if empty topic is possible subscriber.filter.Topic = relay.DefaultWakuTopic } @@ -144,7 +147,6 @@ func (wf *WakuFilter) onRequest(ctx context.Context) func(s network.Stream) { logger.Info("adding subscriber") wf.metrics.RecordSubscribers(subscribersLen) } else { - peerID := s.Conn().RemotePeer() wf.subscribers.RemoveContentFilters(peerID, filterRPCRequest.RequestId, filterRPCRequest.Request.ContentFilters) logger.Info("removing subscriber") @@ -152,8 +154,13 @@ func (wf *WakuFilter) onRequest(ctx context.Context) func(s network.Stream) { } } else { logger.Error("can't serve request") + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } + + stream.Close() } } @@ -161,7 +168,7 @@ func (wf *WakuFilter) pushMessage(ctx context.Context, subscriber Subscriber, ms pushRPC := &pb.FilterRPC{RequestId: subscriber.requestID, Push: &pb.MessagePush{Messages: []*wpb.WakuMessage{msg}}} logger := wf.log.With(logging.HostID("peer", subscriber.peer)) - conn, err := wf.h.NewStream(ctx, subscriber.peer, FilterID_v20beta1) + stream, err := wf.h.NewStream(ctx, subscriber.peer, FilterID_v20beta1) if err != nil { wf.subscribers.FlagAsFailure(subscriber.peer) logger.Error("opening peer stream", zap.Error(err)) @@ -169,16 +176,20 @@ func (wf *WakuFilter) pushMessage(ctx context.Context, subscriber Subscriber, ms return err } - defer conn.Close() - writer := pbio.NewDelimitedWriter(conn) + writer := pbio.NewDelimitedWriter(stream) err = writer.WriteMsg(pushRPC) if err != nil { logger.Error("pushing messages to peer", zap.Error(err)) wf.subscribers.FlagAsFailure(subscriber.peer) wf.metrics.RecordError(pushWriteError) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return nil } + stream.Close() + wf.subscribers.FlagAsSuccess(subscriber.peer) return nil } @@ -266,28 +277,30 @@ func (wf *WakuFilter) requestSubscription(ctx context.Context, filter ContentFil ContentFilters: contentFilters, } - var conn network.Stream - conn, err = wf.h.NewStream(ctx, params.selectedPeer, FilterID_v20beta1) + stream, err := wf.h.NewStream(ctx, params.selectedPeer, FilterID_v20beta1) if err != nil { wf.metrics.RecordError(dialFailure) return } - defer conn.Close() - // This is the only successful path to subscription requestID := hex.EncodeToString(protocol.GenerateRequestID()) - writer := pbio.NewDelimitedWriter(conn) + writer := pbio.NewDelimitedWriter(stream) filterRPC := &pb.FilterRPC{RequestId: requestID, Request: request} wf.log.Debug("sending filterRPC", zap.Stringer("rpc", filterRPC)) err = writer.WriteMsg(filterRPC) if err != nil { wf.metrics.RecordError(writeRequestFailure) wf.log.Error("sending filterRPC", zap.Error(err)) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return } + stream.Close() + subscription = new(FilterSubscription) subscription.Peer = params.selectedPeer subscription.RequestID = requestID @@ -297,15 +310,12 @@ func (wf *WakuFilter) requestSubscription(ctx context.Context, filter ContentFil // Unsubscribe is used to stop receiving messages from a peer that match a content filter func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilter, peer peer.ID) error { - - conn, err := wf.h.NewStream(ctx, peer, FilterID_v20beta1) + stream, err := wf.h.NewStream(ctx, peer, FilterID_v20beta1) if err != nil { wf.metrics.RecordError(dialFailure) return err } - defer conn.Close() - // This is the only successful path to subscription id := protocol.GenerateRequestID() @@ -320,14 +330,19 @@ func (wf *WakuFilter) Unsubscribe(ctx context.Context, contentFilter ContentFilt ContentFilters: contentFilters, } - writer := pbio.NewDelimitedWriter(conn) + writer := pbio.NewDelimitedWriter(stream) filterRPC := &pb.FilterRPC{RequestId: hex.EncodeToString(id), Request: request} err = writer.WriteMsg(filterRPC) if err != nil { wf.metrics.RecordError(writeRequestFailure) + if err := stream.Reset(); err != nil { + wf.log.Error("resetting connection", zap.Error(err)) + } return err } + stream.Close() + return nil } diff --git a/waku/v2/protocol/lightpush/waku_lightpush.go b/waku/v2/protocol/lightpush/waku_lightpush.go index 0ac60deee..931de4fb8 100644 --- a/waku/v2/protocol/lightpush/waku_lightpush.go +++ b/waku/v2/protocol/lightpush/waku_lightpush.go @@ -76,23 +76,29 @@ func (wakuLP *WakuLightPush) relayIsNotAvailable() bool { return wakuLP.relay == nil } -func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(s network.Stream) { - return func(s network.Stream) { - defer s.Close() - logger := wakuLP.log.With(logging.HostID("peer", s.Conn().RemotePeer())) +func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream) { + return func(stream network.Stream) { + logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) requestPushRPC := &pb.PushRPC{} - writer := pbio.NewDelimitedWriter(s) - reader := pbio.NewDelimitedReader(s, math.MaxInt32) + writer := pbio.NewDelimitedWriter(stream) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) err := reader.ReadMsg(requestPushRPC) if err != nil { logger.Error("reading request", zap.Error(err)) wakuLP.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + wakuLP.log.Error("resetting connection", zap.Error(err)) + } return } - logger.Info("request received") + logger = logger.With(zap.String("requestID", requestPushRPC.RequestId)) + + responsePushRPC := &pb.PushRPC{} + responsePushRPC.RequestId = requestPushRPC.RequestId + if requestPushRPC.Query != nil { logger.Info("push request") response := new(pb.PushResponse) @@ -113,23 +119,28 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(s network.Strea response.Info = "Could not publish message" } else { response.IsSuccess = true - response.Info = "Totally" // TODO: ask about this + response.Info = "OK" } - responsePushRPC := &pb.PushRPC{} - responsePushRPC.RequestId = requestPushRPC.RequestId responsePushRPC.Response = response err = writer.WriteMsg(responsePushRPC) if err != nil { wakuLP.metrics.RecordError(writeResponseFailure) logger.Error("writing response", zap.Error(err)) - _ = s.Reset() - } else { - logger.Info("response sent") + if err := stream.Reset(); err != nil { + wakuLP.log.Error("resetting connection", zap.Error(err)) + } + return } + + logger.Info("response sent") + stream.Close() } else { wakuLP.metrics.RecordError(emptyRequestBodyFailure) + if err := stream.Reset(); err != nil { + wakuLP.log.Error("resetting connection", zap.Error(err)) + } } if requestPushRPC.Response != nil { @@ -162,16 +173,6 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p wakuLP.metrics.RecordError(dialFailure) return nil, err } - - defer stream.Close() - defer func() { - err := stream.Reset() - if err != nil { - wakuLP.metrics.RecordError(dialFailure) - logger.Error("resetting connection", zap.Error(err)) - } - }() - pushRequestRPC := &pb.PushRPC{RequestId: hex.EncodeToString(params.requestID), Query: req} writer := pbio.NewDelimitedWriter(stream) @@ -181,6 +182,9 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p if err != nil { wakuLP.metrics.RecordError(writeRequestFailure) logger.Error("writing request", zap.Error(err)) + if err := stream.Reset(); err != nil { + wakuLP.log.Error("resetting connection", zap.Error(err)) + } return nil, err } @@ -189,9 +193,14 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p if err != nil { logger.Error("reading response", zap.Error(err)) wakuLP.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + wakuLP.log.Error("resetting connection", zap.Error(err)) + } return nil, err } + stream.Close() + return pushResponseRPC.Response, nil } diff --git a/waku/v2/protocol/peer_exchange/client.go b/waku/v2/protocol/peer_exchange/client.go index 69a607dc1..f1ba7c393 100644 --- a/waku/v2/protocol/peer_exchange/client.go +++ b/waku/v2/protocol/peer_exchange/client.go @@ -57,11 +57,13 @@ func (wakuPX *WakuPeerExchange) Request(ctx context.Context, numPeers int, opts if err != nil { return err } - defer stream.Close() writer := pbio.NewDelimitedWriter(stream) err = writer.WriteMsg(requestRPC) if err != nil { + if err := stream.Reset(); err != nil { + wakuPX.log.Error("resetting connection", zap.Error(err)) + } return err } @@ -69,9 +71,14 @@ func (wakuPX *WakuPeerExchange) Request(ctx context.Context, numPeers int, opts responseRPC := &pb.PeerExchangeRPC{} err = reader.ReadMsg(responseRPC) if err != nil { + if err := stream.Reset(); err != nil { + wakuPX.log.Error("resetting connection", zap.Error(err)) + } return err } + stream.Close() + return wakuPX.handleResponse(ctx, responseRPC.Response) } diff --git a/waku/v2/protocol/peer_exchange/protocol.go b/waku/v2/protocol/peer_exchange/protocol.go index af9f47e85..8230abaaa 100644 --- a/waku/v2/protocol/peer_exchange/protocol.go +++ b/waku/v2/protocol/peer_exchange/protocol.go @@ -87,16 +87,18 @@ func (wakuPX *WakuPeerExchange) start() error { return nil } -func (wakuPX *WakuPeerExchange) onRequest() func(s network.Stream) { - return func(s network.Stream) { - defer s.Close() - logger := wakuPX.log.With(logging.HostID("peer", s.Conn().RemotePeer())) +func (wakuPX *WakuPeerExchange) onRequest() func(network.Stream) { + return func(stream network.Stream) { + logger := wakuPX.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) requestRPC := &pb.PeerExchangeRPC{} - reader := pbio.NewDelimitedReader(s, math.MaxInt32) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) err := reader.ReadMsg(requestRPC) if err != nil { logger.Error("reading request", zap.Error(err)) wakuPX.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + wakuPX.log.Error("resetting connection", zap.Error(err)) + } return } @@ -114,14 +116,19 @@ func (wakuPX *WakuPeerExchange) onRequest() func(s network.Stream) { responseRPC.Response = new(pb.PeerExchangeResponse) responseRPC.Response.PeerInfos = records - writer := pbio.NewDelimitedWriter(s) + writer := pbio.NewDelimitedWriter(stream) err = writer.WriteMsg(responseRPC) if err != nil { logger.Error("writing response", zap.Error(err)) wakuPX.metrics.RecordError(pxFailure) + if err := stream.Reset(); err != nil { + wakuPX.log.Error("resetting connection", zap.Error(err)) + } return } } + + stream.Close() } } diff --git a/waku/v2/protocol/store/waku_store_client.go b/waku/v2/protocol/store/waku_store_client.go index 310417435..3d90d4e39 100644 --- a/waku/v2/protocol/store/waku_store_client.go +++ b/waku/v2/protocol/store/waku_store_client.go @@ -181,11 +181,6 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec return nil, err } - defer stream.Close() - defer func() { - _ = stream.Reset() - }() - historyRequest := &pb.HistoryRPC{Query: q, RequestId: hex.EncodeToString(requestID)} writer := pbio.NewDelimitedWriter(stream) @@ -195,6 +190,9 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec if err != nil { logger.Error("writing request", zap.Error(err)) store.metrics.RecordError(writeRequestFailure) + if err := stream.Reset(); err != nil { + store.log.Error("resetting connection", zap.Error(err)) + } return nil, err } @@ -203,9 +201,14 @@ func (store *WakuStore) queryFrom(ctx context.Context, q *pb.HistoryQuery, selec if err != nil { logger.Error("reading response", zap.Error(err)) store.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + store.log.Error("resetting connection", zap.Error(err)) + } return nil, err } + stream.Close() + if historyResponseRPC.Response == nil { // Empty response return &pb.HistoryResponse{ diff --git a/waku/v2/protocol/store/waku_store_protocol.go b/waku/v2/protocol/store/waku_store_protocol.go index afb5c5a42..6b5d0b8cc 100644 --- a/waku/v2/protocol/store/waku_store_protocol.go +++ b/waku/v2/protocol/store/waku_store_protocol.go @@ -164,18 +164,20 @@ func (store *WakuStore) storeIncomingMessages(ctx context.Context) { } } -func (store *WakuStore) onRequest(s network.Stream) { - defer s.Close() - logger := store.log.With(logging.HostID("peer", s.Conn().RemotePeer())) +func (store *WakuStore) onRequest(stream network.Stream) { + logger := store.log.With(logging.HostID("peer", stream.Conn().RemotePeer())) historyRPCRequest := &pb.HistoryRPC{} - writer := pbio.NewDelimitedWriter(s) - reader := pbio.NewDelimitedReader(s, math.MaxInt32) + writer := pbio.NewDelimitedWriter(stream) + reader := pbio.NewDelimitedReader(stream, math.MaxInt32) err := reader.ReadMsg(historyRPCRequest) if err != nil { logger.Error("reading request", zap.Error(err)) store.metrics.RecordError(decodeRPCFailure) + if err := stream.Reset(); err != nil { + store.log.Error("resetting connection", zap.Error(err)) + } return } @@ -185,6 +187,9 @@ func (store *WakuStore) onRequest(s network.Stream) { } else { logger.Error("reading request", zap.Error(err)) store.metrics.RecordError(emptyRPCQueryFailure) + if err := stream.Reset(); err != nil { + store.log.Error("resetting connection", zap.Error(err)) + } return } @@ -200,13 +205,15 @@ func (store *WakuStore) onRequest(s network.Stream) { if err != nil { logger.Error("writing response", zap.Error(err), logging.PagingInfo(historyResponseRPC.Response.PagingInfo)) store.metrics.RecordError(writeResponseFailure) - _ = s.Reset() - } else { - logger.Info("response sent") + if err := stream.Reset(); err != nil { + store.log.Error("resetting connection", zap.Error(err)) + } + return } -} -// TODO: queryWithAccounting + logger.Info("response sent") + stream.Close() +} // Stop closes the store message channel and removes the protocol stream handler func (store *WakuStore) Stop() {