Skip to content
This repository has been archived by the owner on Aug 2, 2021. It is now read-only.

Commit

Permalink
swarm/pss: Create outbox queue for pss
Browse files Browse the repository at this point in the history
Implements a queue manager to enable resending when forwarding fails.
Messages are not forwarded right away, but put in a queue which is in
turn fetched by a loop started when the service starts.

swarm/pss: WIP outbox

swarm/pss: Add read mutexes

swarm/pss: Implement queue as channel

swarm/pss: Remove commented code
  • Loading branch information
nolash committed Apr 2, 2018
1 parent 17977c2 commit 8645ead
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 47 deletions.
2 changes: 1 addition & 1 deletion swarm/pss/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func (self *HandshakeController) handleKeys(pubkeyid string, keymsg *handshakeMs
copy(sendsymkey, key)
var address PssAddress
copy(address[:], keymsg.From)
sendsymkeyid, err := self.pss.SetSymmetricKey(sendsymkey, keymsg.Topic, &address, false)
sendsymkeyid, err := self.pss.setSymmetricKey(sendsymkey, keymsg.Topic, &address, false, false)
if err != nil {
return err
}
Expand Down
103 changes: 59 additions & 44 deletions swarm/pss/pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"crypto/ecdsa"
"crypto/rand"
"errors"
"fmt"
"sync"
"time"
Expand All @@ -31,7 +30,9 @@ const (
defaultWhisperWorkTime = 3
defaultWhisperPoW = 0.0000000001
defaultMaxMsgSize = 1024 * 1024
defaultCleanInterval = 1000 * 60 * 10
defaultCleanInterval = time.Second * 60 * 10
defaultDequeueInterval = time.Millisecond * 10
defaultOutboxQueueSize = 10000
pssProtocolName = "pss"
pssVersion = 1
)
Expand Down Expand Up @@ -94,13 +95,14 @@ type Pss struct {

// sending and forwarding
fwdPool map[string]*protocols.Peer // keep track of all peers sitting on the pssmsg routing layer
fwdPoolMu sync.Mutex
fwdPoolMu sync.RWMutex
fwdCache map[pssDigest]pssCacheEntry // checksum of unique fields from pssmsg mapped to expiry, cache to determine whether to drop msg
fwdCacheMu sync.Mutex
fwdCacheMu sync.RWMutex
cacheTTL time.Duration // how long to keep messages in fwdCache (not implemented)
msgTTL time.Duration
paddingByteSize int
capstring string
outbox chan *PssMsg

// keys and peers
pubKeyPool map[string]map[Topic]*pssPeer // mapping of hex public keys to peer address by topic.
Expand All @@ -113,7 +115,7 @@ type Pss struct {

// message handling
handlers map[Topic]map[*Handler]bool // topic and version based pss payload handlers. See pss.Handle()
handlersMu sync.Mutex
handlersMu sync.RWMutex

// process
quitC chan struct{}
Expand Down Expand Up @@ -145,6 +147,7 @@ func NewPss(k network.Overlay, dpa *storage.DPA, params *PssParams) *Pss {
msgTTL: params.MsgTTL,
paddingByteSize: defaultPaddingByteSize,
capstring: cap.String(),
outbox: make(chan *PssMsg, defaultOutboxQueueSize),

pubKeyPool: make(map[string]map[Topic]*pssPeer),
symKeyPool: make(map[string]map[Topic]*pssPeer),
Expand All @@ -161,12 +164,24 @@ func NewPss(k network.Overlay, dpa *storage.DPA, params *PssParams) *Pss {

func (self *Pss) Start(srv *p2p.Server) error {
go func() {
tickC := time.Tick(defaultCleanInterval)
select {
case <-tickC:
self.cleanKeys()
case <-self.quitC:
log.Info("pss shutting down")
for {
tickC := time.Tick(defaultCleanInterval)
select {
case <-tickC:
self.cleanKeys()
case <-self.quitC:
log.Info("pss shutting down")
}
}
}()
go func() {
for {
select {
case msg := <-self.outbox:
self.forward(msg)
case <-self.quitC:
log.Info("pss shutting down")
}
}
}()
log.Debug("Started pss", "public key", common.ToHex(crypto.FromECDSAPub(self.PublicKey())))
Expand Down Expand Up @@ -272,8 +287,8 @@ func (self *Pss) deregister(topic *Topic, h *Handler) {

// get all registered handlers for respective topics
func (self *Pss) getHandlers(topic Topic) map[*Handler]bool {
self.handlersMu.Lock()
defer self.handlersMu.Unlock()
self.handlersMu.RLock()
defer self.handlersMu.RUnlock()
return self.handlers[topic]
}

Expand All @@ -286,20 +301,13 @@ func (self *Pss) handlePssMsg(msg interface{}) error {
if ok {
var err error
if !self.isSelfPossibleRecipient(pssmsg) {
msgexp := time.Unix(int64(pssmsg.Expire), 0)
if msgexp.Before(time.Now()) {
log.Trace("pss expired :/ ... dropping")
return nil
} else if msgexp.After(time.Now().Add(self.msgTTL)) {
return errors.New("Invalid TTL")
}
log.Trace("pss was for someone else :'( ... forwarding", "pss", common.ToHex(self.BaseAddr()))
return self.forward(pssmsg)
self.outbox <- pssmsg
}
log.Trace("pss for us, yay! ... let's process!", "pss", common.ToHex(self.BaseAddr()))

if !self.process(pssmsg) {
err = self.forward(pssmsg)
self.outbox <- pssmsg
}
return err
}
Expand Down Expand Up @@ -335,10 +343,7 @@ func (self *Pss) process(pssmsg *PssMsg) bool {

if len(pssmsg.To) < addressLength {
go func() {
err := self.forward(pssmsg)
if err != nil {
log.Warn("Redundant forward fail: %v", err)
}
self.outbox <- pssmsg
}()
}
handlers := self.getHandlers(psstopic)
Expand Down Expand Up @@ -401,7 +406,7 @@ func (self *Pss) generateSymmetricKey(topic Topic, address *PssAddress, addToCac
if err != nil {
return "", err
}
self.addSymmetricKeyToPool(keyid, topic, address, addToCache)
self.addSymmetricKeyToPool(keyid, topic, address, addToCache, false)
return keyid, nil
}

Expand All @@ -418,20 +423,25 @@ func (self *Pss) generateSymmetricKey(topic Topic, address *PssAddress, addToCac
// Returns a string id that can be used to retrieve the key bytes
// from the whisper backend (see pss.GetSymmetricKey())
func (self *Pss) SetSymmetricKey(key []byte, topic Topic, address *PssAddress, addtocache bool) (string, error) {
return self.setSymmetricKey(key, topic, address, addtocache, true)
}

func (self *Pss) setSymmetricKey(key []byte, topic Topic, address *PssAddress, addtocache bool, protected bool) (string, error) {
keyid, err := self.w.AddSymKeyDirect(key)
if err != nil {
return "", err
}
self.addSymmetricKeyToPool(keyid, topic, address, addtocache)
self.addSymmetricKeyToPool(keyid, topic, address, addtocache, protected)
return keyid, nil
}

// adds a symmetric key to the pss key pool, and optionally adds the key
// to the collection of keys used to attempt symmetric decryption of
// incoming messages
func (self *Pss) addSymmetricKeyToPool(keyid string, topic Topic, address *PssAddress, addtocache bool) {
func (self *Pss) addSymmetricKeyToPool(keyid string, topic Topic, address *PssAddress, addtocache bool, protected bool) {
psp := &pssPeer{
address: address,
address: address,
protected: protected,
}
self.symKeyPoolMu.Lock()
if _, ok := self.symKeyPool[keyid]; !ok {
Expand Down Expand Up @@ -644,29 +654,23 @@ func (self *Pss) send(to []byte, topic Topic, msg []byte, asymmetric bool, key [
Expire: uint32(time.Now().Add(self.msgTTL).Unix()),
Payload: envelope,
}
return self.forward(pssmsg)
self.outbox <- pssmsg
return nil
}

// Forwards a pss message to the peer(s) closest to the to recipient address in the PssMsg struct
// The recipient address can be of any length, and the byte slice will be matched to the MSB slice
// of the peer address of the equivalent length.
func (self *Pss) forward(msg *PssMsg) error {
func (self *Pss) forward(msg *PssMsg) {
to := make([]byte, addressLength)
copy(to[:len(msg.To)], msg.To)

// cache the message
// message hash
digest, err := self.storeMsg(msg)
if err != nil {
log.Warn(fmt.Sprintf("could not store message %v to cache: %v", msg, err))
}

// flood guard:
// don't allow identical messages we saw shortly before
if self.checkFwdCache(nil, digest) {
log.Trace(fmt.Sprintf("pss relay block-cache match: FROM %x TO %x", self.Overlay.BaseAddr(), common.ToHex(msg.To)))
return nil
}

// send with kademlia
// find the closest peer to the recipient and attempt to send
sent := 0
Expand Down Expand Up @@ -696,7 +700,9 @@ func (self *Pss) forward(msg *PssMsg) error {

// get the protocol peer from the forwarding peer cache
sendMsg := fmt.Sprintf("MSG %x TO %x FROM %x VIA %x", digest, to, self.BaseAddr(), op.Address())
self.fwdPoolMu.RLock()
pp := self.fwdPool[sp.Info().ID]
self.fwdPoolMu.RUnlock()
if self.checkFwdCache(op.Address(), digest) {
log.Trace(fmt.Sprintf("%v: peer already forwarded to", sendMsg))
return true
Expand Down Expand Up @@ -730,11 +736,12 @@ func (self *Pss) forward(msg *PssMsg) error {

if sent == 0 {
log.Debug("unable to forward to any peers")
return nil
time.Sleep(time.Millisecond)
self.outbox <- msg
}

// cache the message
self.addFwdCache(digest)
return nil
}

/////////////////////////////////////////////////////////////////////
Expand All @@ -757,8 +764,8 @@ func (self *Pss) addFwdCache(digest pssDigest) error {

// check if message is in the cache
func (self *Pss) checkFwdCache(addr []byte, digest pssDigest) bool {
self.fwdCacheMu.Lock()
defer self.fwdCacheMu.Unlock()
self.fwdCacheMu.RLock()
defer self.fwdCacheMu.RUnlock()
entry, ok := self.fwdCache[digest]
if ok {
if entry.expiresAt.After(time.Now()) {
Expand All @@ -785,3 +792,11 @@ func (self *Pss) storeMsg(msg *PssMsg) (pssDigest, error) {
copy(digest[:], key[:digestLength])
return digest, nil
}

func (self *Pss) isMsgExpired(msg *PssMsg) bool {
msgexp := time.Unix(int64(msg.Expire), 0)
if msgexp.Before(time.Now()) || msgexp.After(time.Now().Add(self.msgTTL)) {
return true
}
return false
}
4 changes: 2 additions & 2 deletions swarm/pss/pss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ func testSymSend(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = clients[1].Call(&rkeyids, "psstest_setSymKeys", rpubkeyhex, rrecvkey, lrecvkey, defaultSymKeySendLimit, topic, loaddrhex)
err = clients[1].Call(&rkeyids, "psstest_setSymKeys", lpubkeyhex, rrecvkey, lrecvkey, defaultSymKeySendLimit, topic, loaddrhex)
if err != nil {
t.Fatal(err)
}
Expand All @@ -504,7 +504,7 @@ func testSymSend(t *testing.T) {
select {
case recvmsg := <-rmsgC:
if !bytes.Equal(recvmsg.Msg, rmsg) {
t.Fatalf("node 2 received payload mismatch: expected %v, got %v", rmsg, recvmsg.Msg)
t.Fatalf("node 2 received payload mismatch: expected %x, got %v", rmsg, recvmsg.Msg)
}
case cerr := <-rctx.Done():
t.Fatalf("test message timed out: %v", cerr)
Expand Down

0 comments on commit 8645ead

Please sign in to comment.