diff --git a/pstoreds/ds_test.go b/pstoreds/ds_test.go index 01741b1..b76bc09 100644 --- a/pstoreds/ds_test.go +++ b/pstoreds/ds_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + ds "github.com/ipfs/go-datastore" badger "github.com/ipfs/go-ds-badger" leveldb "github.com/ipfs/go-ds-leveldb" @@ -27,6 +29,18 @@ func TestDsPeerstore(t *testing.T) { t.Run(name, func(t *testing.T) { pt.TestPeerstore(t, peerstoreFactory(t, dsFactory, DefaultOpts())) }) + + t.Run("protobook limits", func(t *testing.T) { + const limit = 10 + opts := DefaultOpts() + opts.MaxProtocols = limit + ds, close := dsFactory(t) + defer close() + ps, err := NewPeerstore(context.Background(), ds, opts) + require.NoError(t, err) + defer ps.Close() + pt.TestPeerstoreProtoStoreLimits(t, ps, limit) + }) } } diff --git a/pstoreds/peerstore.go b/pstoreds/peerstore.go index bea64dc..6351e9d 100644 --- a/pstoreds/peerstore.go +++ b/pstoreds/peerstore.go @@ -6,13 +6,13 @@ import ( "io" "time" - base32 "github.com/multiformats/go-base32" + "github.com/multiformats/go-base32" ds "github.com/ipfs/go-datastore" - query "github.com/ipfs/go-datastore/query" + "github.com/ipfs/go-datastore/query" - peer "github.com/libp2p/go-libp2p-core/peer" - peerstore "github.com/libp2p/go-libp2p-core/peerstore" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peerstore" pstore "github.com/libp2p/go-libp2p-peerstore" ) @@ -21,6 +21,9 @@ type Options struct { // The size of the in-memory cache. A value of 0 or lower disables the cache. CacheSize uint + // MaxProtocols is the maximum number of protocols we store for one peer. + MaxProtocols int + // Sweep interval to purge expired addresses from the datastore. If this is a zero value, GC will not run // automatically, but it'll be available on demand via explicit calls. GCPurgeInterval time.Duration @@ -37,12 +40,14 @@ type Options struct { // DefaultOpts returns the default options for a persistent peerstore, with the full-purge GC algorithm: // // * Cache size: 1024. +// * MaxProtocols: 1024. // * GC purge interval: 2 hours. // * GC lookahead interval: disabled. // * GC initial delay: 60 seconds. func DefaultOpts() Options { return Options{ CacheSize: 1024, + MaxProtocols: 1024, GCPurgeInterval: 2 * time.Hour, GCLookaheadInterval: 0, GCInitialDelay: 60 * time.Second, @@ -75,7 +80,10 @@ func NewPeerstore(ctx context.Context, store ds.Batching, opts Options) (*pstore return nil, err } - protoBook := NewProtoBook(peerMetadata) + protoBook, err := NewProtoBook(peerMetadata, WithMaxProtocols(opts.MaxProtocols)) + if err != nil { + return nil, err + } ps := &pstoreds{ Metrics: pstore.NewMetrics(), diff --git a/pstoreds/protobook.go b/pstoreds/protobook.go index 6518c4e..d4c0bce 100644 --- a/pstoreds/protobook.go +++ b/pstoreds/protobook.go @@ -1,10 +1,11 @@ package pstoreds import ( + "errors" "fmt" "sync" - peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peer" pstore "github.com/libp2p/go-libp2p-core/peerstore" ) @@ -19,15 +20,27 @@ func (s *protoSegments) get(p peer.ID) *protoSegment { return s[byte(p[len(p)-1])] } +var errTooManyProtocols = errors.New("too many protocols") + +type ProtoBookOption func(*dsProtoBook) error + +func WithMaxProtocols(num int) ProtoBookOption { + return func(pb *dsProtoBook) error { + pb.maxProtos = num + return nil + } +} + type dsProtoBook struct { - segments protoSegments - meta pstore.PeerMetadata + segments protoSegments + meta pstore.PeerMetadata + maxProtos int } var _ pstore.ProtoBook = (*dsProtoBook)(nil) -func NewProtoBook(meta pstore.PeerMetadata) *dsProtoBook { - return &dsProtoBook{ +func NewProtoBook(meta pstore.PeerMetadata, opts ...ProtoBookOption) (*dsProtoBook, error) { + pb := &dsProtoBook{ meta: meta, segments: func() (ret protoSegments) { for i := range ret { @@ -35,23 +48,34 @@ func NewProtoBook(meta pstore.PeerMetadata) *dsProtoBook { } return ret }(), + maxProtos: 1024, } + + for _, opt := range opts { + if err := opt(pb); err != nil { + return nil, err + } + } + return pb, nil } func (pb *dsProtoBook) SetProtocols(p peer.ID, protos ...string) error { if err := p.Validate(); err != nil { return err } - - s := pb.segments.get(p) - s.Lock() - defer s.Unlock() + if len(protos) > pb.maxProtos { + return errTooManyProtocols + } protomap := make(map[string]struct{}, len(protos)) for _, proto := range protos { protomap[proto] = struct{}{} } + s := pb.segments.get(p) + s.Lock() + defer s.Unlock() + return pb.meta.Put(p, "protocols", protomap) } @@ -68,6 +92,9 @@ func (pb *dsProtoBook) AddProtocols(p peer.ID, protos ...string) error { if err != nil { return err } + if len(pmap)+len(protos) > pb.maxProtos { + return errTooManyProtocols + } for _, proto := range protos { pmap[proto] = struct{}{} diff --git a/pstoremem/inmem_test.go b/pstoremem/inmem_test.go index 7491a93..1cc584e 100644 --- a/pstoremem/inmem_test.go +++ b/pstoremem/inmem_test.go @@ -3,6 +3,8 @@ package pstoremem import ( "testing" + "github.com/stretchr/testify/require" + pstore "github.com/libp2p/go-libp2p-core/peerstore" pt "github.com/libp2p/go-libp2p-peerstore/test" @@ -13,42 +15,56 @@ func TestFuzzInMemoryPeerstore(t *testing.T) { // Just create and close a bunch of peerstores. If this leaks, we'll // catch it in the leak check below. for i := 0; i < 100; i++ { - ps := NewPeerstore() + ps, err := NewPeerstore() + require.NoError(t, err) ps.Close() } } func TestInMemoryPeerstore(t *testing.T) { pt.TestPeerstore(t, func() (pstore.Peerstore, func()) { - ps := NewPeerstore() + ps, err := NewPeerstore() + require.NoError(t, err) return ps, func() { ps.Close() } }) } +func TestPeerstoreProtoStoreLimits(t *testing.T) { + const limit = 10 + ps, err := NewPeerstore(WithMaxProtocols(limit)) + require.NoError(t, err) + defer ps.Close() + pt.TestPeerstoreProtoStoreLimits(t, ps, limit) +} + func TestInMemoryAddrBook(t *testing.T) { pt.TestAddrBook(t, func() (pstore.AddrBook, func()) { - ps := NewPeerstore() + ps, err := NewPeerstore() + require.NoError(t, err) return ps, func() { ps.Close() } }) } func TestInMemoryKeyBook(t *testing.T) { pt.TestKeyBook(t, func() (pstore.KeyBook, func()) { - ps := NewPeerstore() + ps, err := NewPeerstore() + require.NoError(t, err) return ps, func() { ps.Close() } }) } func BenchmarkInMemoryPeerstore(b *testing.B) { pt.BenchmarkPeerstore(b, func() (pstore.Peerstore, func()) { - ps := NewPeerstore() + ps, err := NewPeerstore() + require.NoError(b, err) return ps, func() { ps.Close() } }, "InMem") } func BenchmarkInMemoryKeyBook(b *testing.B) { pt.BenchmarkKeyBook(b, func() (pstore.KeyBook, func()) { - ps := NewPeerstore() + ps, err := NewPeerstore() + require.NoError(b, err) return ps, func() { ps.Close() } }) } diff --git a/pstoremem/peerstore.go b/pstoremem/peerstore.go index 113b1d7..dc29149 100644 --- a/pstoremem/peerstore.go +++ b/pstoremem/peerstore.go @@ -2,10 +2,11 @@ package pstoremem import ( "fmt" + "io" + "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peerstore" pstore "github.com/libp2p/go-libp2p-peerstore" - "io" ) type pstoremem struct { @@ -17,15 +18,26 @@ type pstoremem struct { *memoryPeerMetadata } +func WithMaxProtocols(num int) Option { + return func(pb *memoryProtoBook) error { + pb.maxProtos = num + return nil + } +} + // NewPeerstore creates an in-memory threadsafe collection of peers. -func NewPeerstore() *pstoremem { +func NewPeerstore(opts ...Option) (*pstoremem, error) { + pb, err := NewProtoBook(opts...) + if err != nil { + return nil, err + } return &pstoremem{ Metrics: pstore.NewMetrics(), memoryKeyBook: NewKeyBook(), memoryAddrBook: NewAddrBook(), - memoryProtoBook: NewProtoBook(), + memoryProtoBook: pb, memoryPeerMetadata: NewPeerMetadata(), - } + }, nil } func (ps *pstoremem) Close() (err error) { diff --git a/pstoremem/protobook.go b/pstoremem/protobook.go index 1042825..f2fbe50 100644 --- a/pstoremem/protobook.go +++ b/pstoremem/protobook.go @@ -1,9 +1,10 @@ package pstoremem import ( + "errors" "sync" - peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/peer" pstore "github.com/libp2p/go-libp2p-core/peerstore" ) @@ -19,17 +20,23 @@ func (s *protoSegments) get(p peer.ID) *protoSegment { return s[byte(p[len(p)-1])] } +var errTooManyProtocols = errors.New("too many protocols") + +type Option func(*memoryProtoBook) error + type memoryProtoBook struct { segments protoSegments + maxProtos int + lk sync.RWMutex interned map[string]string } var _ pstore.ProtoBook = (*memoryProtoBook)(nil) -func NewProtoBook() *memoryProtoBook { - return &memoryProtoBook{ +func NewProtoBook(opts ...Option) (*memoryProtoBook, error) { + pb := &memoryProtoBook{ interned: make(map[string]string, 256), segments: func() (ret protoSegments) { for i := range ret { @@ -39,7 +46,15 @@ func NewProtoBook() *memoryProtoBook { } return ret }(), + maxProtos: 1024, + } + + for _, opt := range opts { + if err := opt(pb); err != nil { + return nil, err + } } + return pb, nil } func (pb *memoryProtoBook) internProtocol(proto string) string { @@ -70,17 +85,19 @@ func (pb *memoryProtoBook) SetProtocols(p peer.ID, protos ...string) error { if err := p.Validate(); err != nil { return err } - - s := pb.segments.get(p) - s.Lock() - defer s.Unlock() + if len(protos) > pb.maxProtos { + return errTooManyProtocols + } newprotos := make(map[string]struct{}, len(protos)) for _, proto := range protos { newprotos[pb.internProtocol(proto)] = struct{}{} } + s := pb.segments.get(p) + s.Lock() s.protocols[p] = newprotos + s.Unlock() return nil } @@ -99,11 +116,13 @@ func (pb *memoryProtoBook) AddProtocols(p peer.ID, protos ...string) error { protomap = make(map[string]struct{}) s.protocols[p] = protomap } + if len(protomap)+len(protos) > pb.maxProtos { + return errTooManyProtocols + } for _, proto := range protos { protomap[pb.internProtocol(proto)] = struct{}{} } - return nil } diff --git a/test/peerstore_suite.go b/test/peerstore_suite.go index ec0a39c..bdfd9ab 100644 --- a/test/peerstore_suite.go +++ b/test/peerstore_suite.go @@ -425,3 +425,23 @@ func getAddrs(t *testing.T, n int) []ma.Multiaddr { } return addrs } + +func TestPeerstoreProtoStoreLimits(t *testing.T, ps pstore.Peerstore, limit int) { + p := peer.ID("foobar") + protocols := make([]string, limit) + for i := 0; i < limit; i++ { + protocols[i] = fmt.Sprintf("protocol %d", i) + } + + t.Run("setting protocols", func(t *testing.T) { + require.NoError(t, ps.SetProtocols(p, protocols...)) + require.EqualError(t, ps.SetProtocols(p, append(protocols, "proto")...), "too many protocols") + }) + t.Run("adding protocols", func(t *testing.T) { + p1 := protocols[:limit/2] + p2 := protocols[limit/2:] + require.NoError(t, ps.SetProtocols(p, p1...)) + require.NoError(t, ps.AddProtocols(p, p2...)) + require.EqualError(t, ps.AddProtocols(p, "proto"), "too many protocols") + }) +}