From ad30333581e9050f70c64b81cff0047671949ce4 Mon Sep 17 00:00:00 2001 From: Brian Tiger Chow Date: Thu, 11 Sep 2014 01:02:52 -0700 Subject: [PATCH] fix(bitswap:ledger) race conditions https://github.com/jbenet/go-ipfs/issues/39 --- bitswap/ledger.go | 23 +++++++++++++++++++++++ bitswap/ledger_test.go | 23 +++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 bitswap/ledger_test.go diff --git a/bitswap/ledger.go b/bitswap/ledger.go index b8b58e5a665..6ddc0a71107 100644 --- a/bitswap/ledger.go +++ b/bitswap/ledger.go @@ -1,6 +1,7 @@ package bitswap import ( + "sync" "time" peer "github.com/jbenet/go-ipfs/peer" @@ -9,6 +10,7 @@ import ( // Ledger stores the data exchange relationship between two peers. type Ledger struct { + lock sync.RWMutex // Partner is the remote Peer. Partner *peer.Peer @@ -35,16 +37,25 @@ type Ledger struct { type LedgerMap map[u.Key]*Ledger func (l *Ledger) ShouldSend() bool { + l.lock.Lock() + defer l.lock.Unlock() + return l.Strategy(l) } func (l *Ledger) SentBytes(n int) { + l.lock.Lock() + defer l.lock.Unlock() + l.exchangeCount++ l.lastExchange = time.Now() l.Accounting.BytesSent += uint64(n) } func (l *Ledger) ReceivedBytes(n int) { + l.lock.Lock() + defer l.lock.Unlock() + l.exchangeCount++ l.lastExchange = time.Now() l.Accounting.BytesRecv += uint64(n) @@ -52,10 +63,22 @@ func (l *Ledger) ReceivedBytes(n int) { // TODO: this needs to be different. We need timeouts. func (l *Ledger) Wants(k u.Key) { + l.lock.Lock() + defer l.lock.Unlock() + l.wantList[k] = struct{}{} } func (l *Ledger) WantListContains(k u.Key) bool { + l.lock.RLock() + defer l.lock.RUnlock() + _, ok := l.wantList[k] return ok } + +func (l *Ledger) ExchangeCount() uint64 { + l.lock.RLock() + defer l.lock.RUnlock() + return l.exchangeCount +} diff --git a/bitswap/ledger_test.go b/bitswap/ledger_test.go new file mode 100644 index 00000000000..d651d485ff7 --- /dev/null +++ b/bitswap/ledger_test.go @@ -0,0 +1,23 @@ +package bitswap + +import ( + "sync" + "testing" +) + +func TestRaceConditions(t *testing.T) { + const numberOfExpectedExchanges = 10000 + l := new(Ledger) + var wg sync.WaitGroup + for i := 0; i < numberOfExpectedExchanges; i++ { + wg.Add(1) + go func() { + defer wg.Done() + l.ReceivedBytes(1) + }() + } + wg.Wait() + if l.ExchangeCount() != numberOfExpectedExchanges { + t.Fail() + } +}