Skip to content

Commit

Permalink
refactor: validate protobuffers for lightpush and relay (#824)
Browse files Browse the repository at this point in the history
  • Loading branch information
richard-ramos committed Oct 24, 2023
1 parent 7415c5d commit 83c8d18
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 65 deletions.
9 changes: 7 additions & 2 deletions waku/v2/node/wakunode2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func createTestMsg(version uint32) *pb.WakuMessage {
message.Payload = []byte{0, 1, 2}
message.Version = version
message.Timestamp = 123456
message.ContentTopic = "abc"
return message
}

Expand Down Expand Up @@ -264,7 +265,8 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
time.Sleep(2 * time.Second)

_, filter, err := wakuNode2.LegacyFilter().Subscribe(ctx, legacy_filter.ContentFilter{
Topic: string(relay.DefaultWakuTopic),
Topic: string(relay.DefaultWakuTopic),
ContentTopics: []string{"abc"},
}, legacy_filter.WithPeer(wakuNode1.host.ID()))
require.NoError(t, err)

Expand All @@ -281,7 +283,10 @@ func TestDecoupledStoreFromRelay(t *testing.T) {
go func() {
// MSG1 should be pushed in NODE2 via filter
defer wg.Done()
env := <-filter.Chan
env, ok := <-filter.Chan
if !ok {
require.Fail(t, "no message")
}
require.Equal(t, msg.Timestamp, env.Message().Timestamp)
}()

Expand Down
4 changes: 0 additions & 4 deletions waku/v2/protocol/legacy_filter/filter_map.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ func (fm *FilterMap) Notify(msg *pb.WakuMessage, requestID string) {
// Broadcasting message so it's stored
fm.broadcaster.Submit(envelope)

if msg.ContentTopic == "" {
filter.Chan <- envelope
}

// TODO: In case of no topics we should either trigger here for all messages,
// or we should not allow such filter to exist in the first place.
for _, contentTopic := range filter.ContentFilters {
Expand Down
16 changes: 8 additions & 8 deletions waku/v2/protocol/lightpush/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ func (m *metricsImpl) RecordMessage() {
type metricsErrCategory string

var (
decodeRPCFailure metricsErrCategory = "decode_rpc_failure"
writeRequestFailure metricsErrCategory = "write_request_failure"
writeResponseFailure metricsErrCategory = "write_response_failure"
dialFailure metricsErrCategory = "dial_failure"
messagePushFailure metricsErrCategory = "message_push_failure"
emptyRequestBodyFailure metricsErrCategory = "empty_request_body_failure"
emptyResponseBodyFailure metricsErrCategory = "empty_response_body_failure"
peerNotFoundFailure metricsErrCategory = "peer_not_found_failure"
decodeRPCFailure metricsErrCategory = "decode_rpc_failure"
writeRequestFailure metricsErrCategory = "write_request_failure"
writeResponseFailure metricsErrCategory = "write_response_failure"
dialFailure metricsErrCategory = "dial_failure"
messagePushFailure metricsErrCategory = "message_push_failure"
requestBodyFailure metricsErrCategory = "request_failure"
responseBodyFailure metricsErrCategory = "response_body_failure"
peerNotFoundFailure metricsErrCategory = "peer_not_found_failure"
)

// RecordError increases the counter for different error types
Expand Down
48 changes: 48 additions & 0 deletions waku/v2/protocol/lightpush/pb/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package pb

import "errors"

var (
errMissingRequestID = errors.New("missing RequestId field")
errMissingQuery = errors.New("missing Query field")
errMissingMessage = errors.New("missing Message field")
errMissingPubsubTopic = errors.New("missing PubsubTopic field")
errRequestIDMismatch = errors.New("RequestID in response does not match request")
errMissingResponse = errors.New("missing Response field")
)

func (x *PushRPC) ValidateRequest() error {
if x.RequestId == "" {
return errMissingRequestID
}

if x.Query == nil {
return errMissingQuery
}

if x.Query.PubsubTopic == "" {
return errMissingPubsubTopic
}

if x.Query.Message == nil {
return errMissingMessage
}

return x.Query.Message.Validate()
}

func (x *PushRPC) ValidateResponse(requestID string) error {
if x.RequestId == "" {
return errMissingRequestID
}

if x.RequestId != requestID {
return errRequestIDMismatch
}

if x.Response == nil {
return errMissingResponse
}

return nil
}
35 changes: 35 additions & 0 deletions waku/v2/protocol/lightpush/pb/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package pb

import (
"testing"

"github.com/stretchr/testify/require"
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
)

func TestValidateRequest(t *testing.T) {
request := PushRPC{}
require.ErrorIs(t, request.ValidateRequest(), errMissingRequestID)
request.RequestId = "test"
require.ErrorIs(t, request.ValidateRequest(), errMissingQuery)
request.Query = &PushRequest{}
require.ErrorIs(t, request.ValidateRequest(), errMissingPubsubTopic)
request.Query.PubsubTopic = "test"
require.ErrorIs(t, request.ValidateRequest(), errMissingMessage)
request.Query.Message = &pb.WakuMessage{
Payload: []byte{1, 2, 3},
ContentTopic: "test",
}
require.NoError(t, request.ValidateRequest())
}

func TestValidateResponse(t *testing.T) {
response := PushRPC{}
require.ErrorIs(t, response.ValidateResponse("test"), errMissingRequestID)
response.RequestId = "test1"
require.ErrorIs(t, response.ValidateResponse("test"), errRequestIDMismatch)
response.RequestId = "test"
require.ErrorIs(t, response.ValidateResponse("test"), errMissingResponse)
response.Response = &PushResponse{}
require.NoError(t, response.ValidateResponse("test"))
}
97 changes: 52 additions & 45 deletions waku/v2/protocol/lightpush/waku_lightpush.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/hex"
"errors"
"fmt"
"math"

"github.com/libp2p/go-libp2p/core/host"
Expand Down Expand Up @@ -81,7 +82,6 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
logger := wakuLP.log.With(logging.HostID("peer", stream.Conn().RemotePeer()))
requestPushRPC := &pb.PushRPC{}

writer := pbio.NewDelimitedWriter(stream)
reader := pbio.NewDelimitedReader(stream, math.MaxInt32)

err := reader.ReadMsg(requestPushRPC)
Expand All @@ -94,65 +94,67 @@ func (wakuLP *WakuLightPush) onRequest(ctx context.Context) func(network.Stream)
return
}

logger = logger.With(zap.String("requestID", requestPushRPC.RequestId))
responsePushRPC := &pb.PushRPC{
RequestId: requestPushRPC.RequestId,
Response: &pb.PushResponse{},
}

responsePushRPC := &pb.PushRPC{}
responsePushRPC.RequestId = requestPushRPC.RequestId
if err := requestPushRPC.ValidateRequest(); err != nil {
responsePushRPC.Response.Info = err.Error()
wakuLP.metrics.RecordError(requestBodyFailure)
wakuLP.reply(stream, responsePushRPC, logger)
return
}

if requestPushRPC.Query != nil {
logger.Info("push request")
response := new(pb.PushResponse)
logger = logger.With(zap.String("requestID", requestPushRPC.RequestId))

pubSubTopic := requestPushRPC.Query.PubsubTopic
message := requestPushRPC.Query.Message
logger.Info("push request")

wakuLP.metrics.RecordMessage()
pubSubTopic := requestPushRPC.Query.PubsubTopic
message := requestPushRPC.Query.Message

// TODO: Assumes success, should probably be extended to check for network, peers, etc
// It might make sense to use WithReadiness option here?
wakuLP.metrics.RecordMessage()

_, err := wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic)
// TODO: Assumes success, should probably be extended to check for network, peers, etc
// It might make sense to use WithReadiness option here?

if err != nil {
logger.Error("publishing message", zap.Error(err))
wakuLP.metrics.RecordError(messagePushFailure)
response.Info = "Could not publish message"
} else {
response.IsSuccess = true
response.Info = "OK"
}
_, err = wakuLP.relay.PublishToTopic(ctx, message, pubSubTopic)
if err != nil {
logger.Error("publishing message", zap.Error(err))
wakuLP.metrics.RecordError(messagePushFailure)
responsePushRPC.Response.Info = fmt.Sprintf("Could not publish message: %s", err.Error())
return
} else {
responsePushRPC.Response.IsSuccess = true
responsePushRPC.Response.Info = "OK"
}

responsePushRPC.Response = response
wakuLP.reply(stream, responsePushRPC, logger)

err = writer.WriteMsg(responsePushRPC)
if err != nil {
wakuLP.metrics.RecordError(writeResponseFailure)
logger.Error("writing response", zap.Error(err))
if err := stream.Reset(); err != nil {
wakuLP.log.Error("resetting connection", zap.Error(err))
}
return
}
logger.Info("response sent")

logger.Info("response sent")
stream.Close()
stream.Close()

if responsePushRPC.Response.IsSuccess {
logger.Info("request success")
} else {
wakuLP.metrics.RecordError(emptyRequestBodyFailure)
if err := stream.Reset(); err != nil {
wakuLP.log.Error("resetting connection", zap.Error(err))
}
logger.Info("request failure", zap.String("info", responsePushRPC.Response.Info))
}
}
}

if requestPushRPC.Response != nil {
if requestPushRPC.Response.IsSuccess {
logger.Info("request success")
} else {
logger.Info("request failure", zap.String("info=", requestPushRPC.Response.Info))
}
} else {
wakuLP.metrics.RecordError(emptyResponseBodyFailure)
func (wakuLP *WakuLightPush) reply(stream network.Stream, responsePushRPC *pb.PushRPC, logger *zap.Logger) {
writer := pbio.NewDelimitedWriter(stream)
err := writer.WriteMsg(responsePushRPC)
if err != nil {
wakuLP.metrics.RecordError(writeResponseFailure)
logger.Error("writing response", zap.Error(err))
if err := stream.Reset(); err != nil {
wakuLP.log.Error("resetting connection", zap.Error(err))
}
return
}
stream.Close()
}

// request sends a message via lightPush protocol to either a specified peer or peer that is selected.
Expand Down Expand Up @@ -201,6 +203,11 @@ func (wakuLP *WakuLightPush) request(ctx context.Context, req *pb.PushRequest, p

stream.Close()

if err = pushResponseRPC.ValidateResponse(pushRequestRPC.RequestId); err != nil {
wakuLP.metrics.RecordError(responseBodyFailure)
return nil, err
}

return pushResponseRPC.Response, nil
}

Expand Down
47 changes: 47 additions & 0 deletions waku/v2/protocol/pb/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package pb

import (
"errors"

"google.golang.org/protobuf/proto"
)

const MaxMetaAttrLength = 64

var (
errMissingPayload = errors.New("missing Payload field")
errMissingContentTopic = errors.New("missing ContentTopic field")
errInvalidMetaLength = errors.New("invalid length for Meta field")
)

func (msg *WakuMessage) Validate() error {
if len(msg.Payload) == 0 {
return errMissingPayload
}

if msg.ContentTopic == "" {
return errMissingContentTopic
}

if len(msg.Meta) > MaxMetaAttrLength {
return errInvalidMetaLength
}

return nil
}

func Unmarshal(data []byte) (*WakuMessage, error) {
msg := &WakuMessage{}
err := proto.Unmarshal(data, msg)
if err != nil {
return nil, err
}

err = msg.Validate()
if err != nil {
return nil, err
}

return msg, nil

}
4 changes: 1 addition & 3 deletions waku/v2/protocol/relay/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/ethereum/go-ethereum/crypto/secp256k1"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
proto "google.golang.org/protobuf/proto"

"github.com/waku-org/go-waku/waku/v2/hash"
"github.com/waku-org/go-waku/waku/v2/protocol/pb"
Expand Down Expand Up @@ -62,8 +61,7 @@ func (w *WakuRelay) RemoveTopicValidator(topic string) {

func (w *WakuRelay) topicValidator(topic string) func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
return func(ctx context.Context, peerID peer.ID, message *pubsub.Message) bool {
msg := new(pb.WakuMessage)
err := proto.Unmarshal(message.Data, msg)
msg, err := pb.Unmarshal(message.Data)
if err != nil {
return false
}
Expand Down
11 changes: 8 additions & 3 deletions waku/v2/protocol/relay/waku_relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,12 +234,15 @@ func (w *WakuRelay) PublishToTopic(ctx context.Context, message *pb.WakuMessage,
return nil, errors.New("message can't be null")
}

if err := message.Validate(); err != nil {
return nil, err
}

if !w.EnoughPeersToPublishToTopic(topic) {
return nil, errors.New("not enough peers to publish")
}

pubSubTopic, err := w.upsertTopic(topic)

if err != nil {
return nil, err
}
Expand Down Expand Up @@ -460,11 +463,13 @@ func (w *WakuRelay) pubsubTopicMsgHandler(pubsubTopic string, sub *pubsub.Subscr
sub.Cancel()
return
}
wakuMessage := &pb.WakuMessage{}
if err := proto.Unmarshal(msg.Data, wakuMessage); err != nil {

wakuMessage, err := pb.Unmarshal(msg.Data)
if err != nil {
w.log.Error("decoding message", zap.Error(err))
return
}

envelope := waku_proto.NewEnvelope(wakuMessage, w.timesource.Now().UnixNano(), pubsubTopic)
w.metrics.RecordMessage(envelope)

Expand Down

0 comments on commit 83c8d18

Please sign in to comment.