Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow caller to optionally validate messages #45

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ install: true
script:
- make deps
- go vet
- go test ./...
- go test -v ./...

cache:
directories:
Expand Down
226 changes: 190 additions & 36 deletions floodsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ import (
timecache "github.com/whyrusleeping/timecache"
)

const ID = protocol.ID("/floodsub/1.0.0")
const (
ID = protocol.ID("/floodsub/1.0.0")
defaultMaxConcurrency = 10
Copy link
Collaborator

@vyzo vyzo Jan 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this a 100 at least (or even 1000) -- it's a global throttle on all messages not just per topic.

defaultValidateTimeout = 150 * time.Millisecond
)

var log = logging.Logger("floodsub")

Expand Down Expand Up @@ -53,6 +57,12 @@ type PubSub struct {
// topics tracks which topics each of our peers are subscribed to
topics map[string]map[peer.ID]struct{}

// sendMsg handles messages that have been validated
sendMsg chan sendReq

// throttleValidate bounds the number of goroutines concurrently validating messages
throttleValidate chan struct{}

peers map[peer.ID]chan *RPC
seenMessages *timecache.TimeCache

Expand All @@ -74,31 +84,49 @@ type RPC struct {
from peer.ID
}

type Option func(*PubSub) error

func WithMaxConcurrency(n int) Option {
return func(ps *PubSub) error {
ps.throttleValidate = make(chan struct{}, n)
return nil
}
}

// NewFloodSub returns a new FloodSub management object
func NewFloodSub(ctx context.Context, h host.Host) *PubSub {
func NewFloodSub(ctx context.Context, h host.Host, opts ...Option) (*PubSub, error) {
ps := &PubSub{
host: h,
ctx: ctx,
incoming: make(chan *RPC, 32),
publish: make(chan *Message),
newPeers: make(chan inet.Stream),
peerDead: make(chan peer.ID),
cancelCh: make(chan *Subscription),
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
seenMessages: timecache.NewTimeCache(time.Second * 30),
host: h,
ctx: ctx,
incoming: make(chan *RPC, 32),
publish: make(chan *Message),
newPeers: make(chan inet.Stream),
peerDead: make(chan peer.ID),
cancelCh: make(chan *Subscription),
getPeers: make(chan *listPeerReq),
addSub: make(chan *addSubReq),
getTopics: make(chan *topicReq),
sendMsg: make(chan sendReq),
myTopics: make(map[string]map[*Subscription]struct{}),
topics: make(map[string]map[peer.ID]struct{}),
peers: make(map[peer.ID]chan *RPC),
seenMessages: timecache.NewTimeCache(time.Second * 30),
throttleValidate: make(chan struct{}, defaultMaxConcurrency),
}

for _, opt := range opts {
err := opt(ps)
if err != nil {
return nil, err
}
}

h.SetStreamHandler(ID, ps.handleNewStream)
h.Network().Notify((*PubSubNotif)(ps))

go ps.processLoop(ctx)

return ps
return ps, nil
}

// processLoop handles all inputs arriving on the channels
Expand Down Expand Up @@ -171,7 +199,27 @@ func (p *PubSub) processLoop(ctx context.Context) {
continue
}
case msg := <-p.publish:
p.maybePublishMessage(p.host.ID(), msg.Message)
subs := p.getSubscriptions(msg) // call before goroutine!

select {
case p.throttleValidate <- struct{}{}:
go func(msg *Message) {
defer func() { <-p.throttleValidate }()

if p.validate(subs, msg) {
p.sendMsg <- sendReq{
from: p.host.ID(),
msg: msg,
}

}
}(msg)
default:
log.Warning("could not acquire validator; dropping message")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should log some more information about the message? The topic and the message id come to mind.

}
case req := <-p.sendMsg:
p.maybePublishMessage(req.from, req.msg.Message)

case <-ctx.Done():
log.Info("pubsub processloop shutting down")
return
Expand Down Expand Up @@ -205,24 +253,22 @@ func (p *PubSub) handleRemoveSubscription(sub *Subscription) {
// subscribes to the topic.
// Only called from processLoop.
func (p *PubSub) handleAddSubscription(req *addSubReq) {
subs := p.myTopics[req.topic]
sub := req.sub
subs := p.myTopics[sub.topic]

// announce we want this topic
if len(subs) == 0 {
p.announce(req.topic, true)
p.announce(sub.topic, true)
}

// make new if not there
if subs == nil {
p.myTopics[req.topic] = make(map[*Subscription]struct{})
subs = p.myTopics[req.topic]
p.myTopics[sub.topic] = make(map[*Subscription]struct{})
subs = p.myTopics[sub.topic]
}

sub := &Subscription{
ch: make(chan *Message, 32),
topic: req.topic,
cancelCh: p.cancelCh,
}
sub.ch = make(chan *Message, 32)
sub.cancelCh = p.cancelCh

p.myTopics[sub.topic][sub] = struct{}{}

Expand Down Expand Up @@ -309,7 +355,23 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) error {
continue
}

p.maybePublishMessage(rpc.from, pmsg)
subs := p.getSubscriptions(&Message{pmsg}) // call before goroutine!

select {
case p.throttleValidate <- struct{}{}:
go func(pmsg *pb.Message) {
defer func() { <-p.throttleValidate }()

if p.validate(subs, &Message{pmsg}) {
p.sendMsg <- sendReq{
from: rpc.from,
msg: &Message{pmsg},
}
}
}(pmsg)
default:
log.Warning("could not acquire validator; dropping message")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe in the default here, we can have another select, just like this one, but instead of a default there, we have a timeout. That way we avoid starting a timer for every message, and we don't immediately drop any overflow

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't that block the main loop? And what would be the difference to just using the timeout instead of the default?

}
}
return nil
}
Expand All @@ -319,6 +381,43 @@ func msgID(pmsg *pb.Message) string {
return string(pmsg.GetFrom()) + string(pmsg.GetSeqno())
}

// validate is called in a goroutine and calls the validate functions of all subs with msg as parameter.
func (p *PubSub) validate(subs []*Subscription, msg *Message) bool {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this validates for all subscriptions -- but there may be a case that the message is valid for some topic and invalid for some other topic. Shouldn't we publish it for the valid topic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I talked to why about this and we weren't sure what to do. For now we just wanted to drop it. Maybe the right thing is to remove the topics it fails for and one drop the message if all topics have been removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reasoning was "we can always make things more complicated" :)

results := make([]chan bool, len(subs))
ctxs := make([]context.Context, len(subs))

for i, sub := range subs {
result := make(chan bool)
ctx, cancel := context.WithTimeout(p.ctx, sub.validateTimeout)
defer cancel()

ctxs[i] = ctx
results[i] = result

go func(sub *Subscription) {
result <- sub.validate == nil || sub.validate(ctx, msg)
}(sub)
}

for i, sub := range subs {
ctx := ctxs[i]
result := results[i]

select {
case valid := <-result:
if !valid {
log.Debugf("validator for topic %s returned false", sub.topic)
return false
}
case <-ctx.Done():
log.Debugf("validator for topic %s timed out. msg: %s", sub.topic, msg)
return false
}
}

return true
}

func (p *PubSub) maybePublishMessage(from peer.ID, pmsg *pb.Message) {
id := msgID(pmsg)
if p.seenMessage(id) {
Expand All @@ -343,7 +442,7 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
continue
}

for p, _ := range tmap {
for p := range tmap {
tosend[p] = struct{}{}
}
}
Expand All @@ -370,20 +469,57 @@ func (p *PubSub) publishMessage(from peer.ID, msg *pb.Message) error {
return nil
}

// getSubscriptions returns all subscriptions the would receive the given message.
func (p *PubSub) getSubscriptions(msg *Message) []*Subscription {
var subs []*Subscription

for _, topic := range msg.GetTopicIDs() {
tSubs, ok := p.myTopics[topic]
if !ok {
continue
}

for sub := range tSubs {
subs = append(subs, sub)
}
}

return subs
}

type addSubReq struct {
topic string
resp chan *Subscription
sub *Subscription
resp chan *Subscription
}

type SubOpt func(*Subscription) error
type Validator func(context.Context, *Message) bool

// WithValidator is an option that can be supplied to Subscribe. The argument is a function that returns whether or not a given message should be propagated further.
func WithValidator(validate Validator) func(*Subscription) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use SubOpt in the return type for this prototype?

return func(sub *Subscription) error {
sub.validate = validate
return nil
}
}

// WithValidatorTimeout is an option that can be supplied to Subscribe. The argument is a duration after which long-running validators are canceled.
func WithValidatorTimeout(timeout time.Duration) func(*Subscription) error {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

return func(sub *Subscription) error {
sub.validateTimeout = timeout
return nil
}
}

// Subscribe returns a new Subscription for the given topic
func (p *PubSub) Subscribe(topic string) (*Subscription, error) {
func (p *PubSub) Subscribe(topic string, opts ...SubOpt) (*Subscription, error) {
td := pb.TopicDescriptor{Name: &topic}

return p.SubscribeByTopicDescriptor(&td)
return p.SubscribeByTopicDescriptor(&td, opts...)
}

// SubscribeByTopicDescriptor lets you subscribe a topic using a pb.TopicDescriptor
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscription, error) {
func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor, opts ...SubOpt) (*Subscription, error) {
if td.GetAuth().GetMode() != pb.TopicDescriptor_AuthOpts_NONE {
return nil, fmt.Errorf("auth mode not yet supported")
}
Expand All @@ -392,10 +528,22 @@ func (p *PubSub) SubscribeByTopicDescriptor(td *pb.TopicDescriptor) (*Subscripti
return nil, fmt.Errorf("encryption mode not yet supported")
}

sub := &Subscription{
topic: td.GetName(),
validateTimeout: defaultValidateTimeout,
}

for _, opt := range opts {
err := opt(sub)
if err != nil {
return nil, err
}
}

out := make(chan *Subscription, 1)
p.addSub <- &addSubReq{
topic: td.GetName(),
resp: out,
sub: sub,
resp: out,
}

return <-out, nil
Expand Down Expand Up @@ -433,6 +581,12 @@ type listPeerReq struct {
topic string
}

// sendReq is a request to call maybePublishMessage. It is issued after the subscription verification is done.
type sendReq struct {
from peer.ID
msg *Message
}

// ListPeers returns a list of peers we are connected to.
func (p *PubSub) ListPeers(topic string) []peer.ID {
out := make(chan []peer.ID)
Expand Down
Loading