Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow disabling value and provider storage/messages #400

Merged
merged 5 commits into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ type IpfsDHT struct {
triggerRtRefresh chan struct{}

maxRecordAge time.Duration

// Allows disabling dht subsystems. These should _only_ be set on
// "forked" DHTs (e.g., DHTs with custom protocols and/or private
// networks).
enableProviders, enableValues bool
}

// Assert that IPFS assumptions about interfaces aren't broken. These aren't a
Expand All @@ -98,6 +103,8 @@ func New(ctx context.Context, h host.Host, options ...opts.Option) (*IpfsDHT, er
dht.rtRefreshQueryTimeout = cfg.RoutingTable.RefreshQueryTimeout

dht.maxRecordAge = cfg.MaxRecordAge
dht.enableProviders = cfg.EnableProviders
dht.enableValues = cfg.EnableValues

// register for network notifs.
dht.host.Network().Notify((*netNotifiee)(dht))
Expand Down
76 changes: 72 additions & 4 deletions dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ func (testAtomicPutValidator) Select(_ string, bs [][]byte) (int, error) {
return index, nil
}

func setupDHT(ctx context.Context, t *testing.T, client bool) *IpfsDHT {
func setupDHT(ctx context.Context, t *testing.T, client bool, options ...opts.Option) *IpfsDHT {
d, err := New(
ctx,
bhost.New(swarmt.GenSwarm(t, ctx, swarmt.OptDisableReuseport)),
opts.Client(client),
opts.NamespacedValidator("v", blankValidator{}),
opts.DisableAutoRefresh(),
append([]opts.Option{
opts.Client(client),
opts.NamespacedValidator("v", blankValidator{}),
opts.DisableAutoRefresh(),
}, options...)...,
)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1419,6 +1421,72 @@ func TestFindClosestPeers(t *testing.T) {
}
}

func TestProvideDisabled(t *testing.T) {
k := testCaseCids[0]
for i := 0; i < 3; i++ {
enabledA := (i & 0x1) > 0
enabledB := (i & 0x2) > 0
t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var (
optsA, optsB []opts.Option
)
if !enabledA {
optsA = append(optsA, opts.DisableProviders())
}
if !enabledB {
optsB = append(optsB, opts.DisableProviders())
}

dhtA := setupDHT(ctx, t, false, optsA...)
dhtB := setupDHT(ctx, t, false, optsB...)

defer dhtA.Close()
defer dhtB.Close()
defer dhtA.host.Close()
defer dhtB.host.Close()

connect(t, ctx, dhtA, dhtB)

err := dhtB.Provide(ctx, k, true)
if enabledB {
if err != nil {
t.Fatal("put should have succeeded on node B", err)
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("should not have put the value to node B", err)
}
_, err = dhtB.FindProviders(ctx, k)
if err != routing.ErrNotSupported {
t.Fatal("get should have failed on node B")
}
provs := dhtB.providers.GetProviders(ctx, k)
if len(provs) != 0 {
t.Fatal("node B should not have found local providers")
}
}

provs, err := dhtA.FindProviders(ctx, k)
if enabledA {
if len(provs) != 0 {
t.Fatal("node A should not have found providers")
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("node A should not have found providers")
}
}
provAddrs := dhtA.providers.GetProviders(ctx, k)
if len(provAddrs) != 0 {
t.Fatal("node A should not have found local providers")
}
})
}
}

func TestGetSetPluggedProtocol(t *testing.T) {
t.Run("PutValue/GetValue - same protocol", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
30 changes: 20 additions & 10 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,31 @@ type dhtHandler func(context.Context, peer.ID, *pb.Message) (*pb.Message, error)

func (dht *IpfsDHT) handlerForMsgType(t pb.Message_MessageType) dhtHandler {
switch t {
case pb.Message_GET_VALUE:
return dht.handleGetValue
case pb.Message_PUT_VALUE:
return dht.handlePutValue
case pb.Message_FIND_NODE:
return dht.handleFindPeer
case pb.Message_ADD_PROVIDER:
return dht.handleAddProvider
case pb.Message_GET_PROVIDERS:
return dht.handleGetProviders
case pb.Message_PING:
return dht.handlePing
default:
return nil
}

if dht.enableValues {
switch t {
case pb.Message_GET_VALUE:
return dht.handleGetValue
case pb.Message_PUT_VALUE:
return dht.handlePutValue
}
}

if dht.enableProviders {
switch t {
case pb.Message_ADD_PROVIDER:
return dht.handleAddProvider
case pb.Message_GET_PROVIDERS:
return dht.handleGetProviders
}
}

return nil
}

func (dht *IpfsDHT) handleGetValue(ctx context.Context, p peer.ID, pmes *pb.Message) (_ *pb.Message, err error) {
Expand Down
43 changes: 37 additions & 6 deletions opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ var (

// Options is a structure containing all the options that can be used when constructing a DHT.
type Options struct {
Datastore ds.Batching
Validator record.Validator
Client bool
Protocols []protocol.ID
BucketSize int
MaxRecordAge time.Duration
Datastore ds.Batching
Validator record.Validator
Client bool
Protocols []protocol.ID
BucketSize int
MaxRecordAge time.Duration
EnableProviders bool
EnableValues bool

RoutingTable struct {
RefreshQueryTimeout time.Duration
Expand Down Expand Up @@ -56,6 +58,8 @@ var Defaults = func(o *Options) error {
}
o.Datastore = dssync.MutexWrap(ds.NewMapDatastore())
o.Protocols = DefaultProtocols
o.EnableProviders = true
o.EnableValues = true

o.RoutingTable.RefreshQueryTimeout = 10 * time.Second
o.RoutingTable.RefreshPeriod = 1 * time.Hour
Expand Down Expand Up @@ -177,3 +181,30 @@ func DisableAutoRefresh() Option {
return nil
}
}

// DisableProviders disables storing and retrieving provider records.
//
// Defaults to enabled.
//
// WARNING: do not change this unless you're using a forked DHT (i.e., a private
// network and/or distinct DHT protocols with the `Protocols` option).
func DisableProviders() Option {
return func(o *Options) error {
o.EnableProviders = false
return nil
}
}

// DisableProviders disables storing and retrieving value records (including
// public keys).
//
// Defaults to enabled.
//
// WARNING: do not change this unless you're using a forked DHT (i.e., a private
// network and/or distinct DHT protocols with the `Protocols` option).
func DisableValues() Option {
return func(o *Options) error {
o.EnableValues = false
return nil
}
}
75 changes: 75 additions & 0 deletions records_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dht
import (
"context"
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p-core/test"
"testing"
"time"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/libp2p/go-libp2p-core/routing"
record "github.com/libp2p/go-libp2p-record"
tnet "github.com/libp2p/go-libp2p-testing/net"

dhtopt "github.com/libp2p/go-libp2p-kad-dht/opts"
)

// Check that GetPublicKey() correctly extracts a public key
Expand Down Expand Up @@ -305,3 +308,75 @@ func TestPubkeyGoodKeyFromDHTGoodKeyDirect(t *testing.T) {
t.Fatal("got incorrect public key")
}
}

func TestValuesDisabled(t *testing.T) {
for i := 0; i < 3; i++ {
enabledA := (i & 0x1) > 0
enabledB := (i & 0x2) > 0
t.Run(fmt.Sprintf("a=%v/b=%v", enabledA, enabledB), func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

var (
optsA, optsB []dhtopt.Option
)
if !enabledA {
optsA = append(optsA, dhtopt.DisableValues())
}
if !enabledB {
optsB = append(optsB, dhtopt.DisableValues())
}

dhtA := setupDHT(ctx, t, false, optsA...)
dhtB := setupDHT(ctx, t, false, optsB...)

defer dhtA.Close()
defer dhtB.Close()
defer dhtA.host.Close()
defer dhtB.host.Close()

connect(t, ctx, dhtA, dhtB)

pubk := dhtB.peerstore.PubKey(dhtB.self)
pkbytes, err := pubk.Bytes()
if err != nil {
t.Fatal(err)
}

pkkey := routing.KeyForPublicKey(dhtB.self)
err = dhtB.PutValue(ctx, pkkey, pkbytes)
if enabledB {
if err != nil {
t.Fatal("put should have succeeded on node B", err)
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("should not have put the value to node B", err)
}
_, err = dhtB.GetValue(ctx, pkkey)
if err != routing.ErrNotSupported {
t.Fatal("get should have failed on node B")
}
rec, _ := dhtB.getLocal(pkkey)
if rec != nil {
t.Fatal("node B should not have found the value locally")
}
}

_, err = dhtA.GetValue(ctx, pkkey)
if enabledA {
if err != routing.ErrNotFound {
t.Fatal("node A should not have found the value")
}
} else {
if err != routing.ErrNotSupported {
t.Fatal("node A should not have found the value")
}
}
rec, _ := dhtA.getLocal(pkkey)
if rec != nil {
t.Fatal("node A should not have found the value locally")
}
})
}
}
30 changes: 28 additions & 2 deletions routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ var asyncQueryBuffer = 10
// PutValue adds value corresponding to given Key.
// This is the top level "Store" operation of the DHT
func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts ...routing.Option) (err error) {
if !dht.enableValues {
return routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "PutValue")
defer func() {
eip.Append(loggableKey(key))
Expand Down Expand Up @@ -110,6 +114,10 @@ type RecvdVal struct {

// GetValue searches for the value corresponding to given Key.
func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Option) (_ []byte, err error) {
if !dht.enableValues {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this correct behaviour? We might know a peer that supports GetValue and so on.

return nil, routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "GetValue")
defer func() {
eip.Append(loggableKey(key))
Expand Down Expand Up @@ -148,6 +156,10 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op
}

func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing.Option) (<-chan []byte, error) {
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

var cfg routing.Options
if err := cfg.Apply(opts...); err != nil {
return nil, err
Expand Down Expand Up @@ -250,8 +262,11 @@ func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...routing

// GetValues gets nvals values corresponding to the given key.
func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) {
eip := logger.EventBegin(ctx, "GetValues")
if !dht.enableValues {
return nil, routing.ErrNotSupported
}

eip := logger.EventBegin(ctx, "GetValues")
eip.Append(loggableKey(key))
defer eip.Done()

Expand Down Expand Up @@ -398,6 +413,9 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-cha

// Provide makes this node announce that it can provide a value for the given key
func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err error) {
if !dht.enableProviders {
return routing.ErrNotSupported
}
eip := logger.EventBegin(ctx, "Provide", key, logging.LoggableMap{"broadcast": brdcst})
defer func() {
if err != nil {
Expand Down Expand Up @@ -477,6 +495,9 @@ func (dht *IpfsDHT) makeProvRecord(skey cid.Cid) (*pb.Message, error) {

// FindProviders searches until the context expires.
func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) {
if !dht.enableProviders {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here.

return nil, routing.ErrNotSupported
}
var providers []peer.AddrInfo
for p := range dht.FindProvidersAsync(ctx, c, dht.bucketSize) {
providers = append(providers, p)
Expand All @@ -488,8 +509,13 @@ func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrIn
// Peers will be returned on the channel as soon as they are found, even before
// the search query completes.
func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count int) <-chan peer.AddrInfo {
logger.Event(ctx, "findProviders", key)
peerOut := make(chan peer.AddrInfo, count)
if !dht.enableProviders {
close(peerOut)
return peerOut
}

logger.Event(ctx, "findProviders", key)

go dht.findProvidersAsyncRoutine(ctx, key, count, peerOut)
return peerOut
Expand Down