diff --git a/dht_test.go b/dht_test.go index 22eb952d275..97a48dc4166 100644 --- a/dht_test.go +++ b/dht_test.go @@ -301,6 +301,110 @@ func TestValueSetInvalid(t *testing.T) { testSetGet("valid", true, "newer", nil) } +func TestSearchValue(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dhtA := setupDHT(ctx, t, false) + dhtB := setupDHT(ctx, t, false) + + defer dhtA.Close() + defer dhtB.Close() + defer dhtA.host.Close() + defer dhtB.host.Close() + + connect(t, ctx, dhtA, dhtB) + + dhtA.Validator.(record.NamespacedValidator)["v"] = testValidator{} + dhtB.Validator.(record.NamespacedValidator)["v"] = testValidator{} + + ctxT, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + err := dhtA.PutValue(ctxT, "/v/hello", []byte("valid")) + if err != nil { + t.Error(err) + } + + ctxT, cancel = context.WithTimeout(ctx, time.Second*2) + defer cancel() + valCh, err := dhtA.SearchValue(ctxT, "/v/hello", Quorum(-1)) + if err != nil { + t.Fatal(err) + } + + select { + case v := <-valCh: + if string(v) != "valid" { + t.Errorf("expected 'valid', got '%s'", string(v)) + } + case <-ctxT.Done(): + t.Fatal(ctxT.Err()) + } + + err = dhtB.PutValue(ctxT, "/v/hello", []byte("newer")) + if err != nil { + t.Error(err) + } + + select { + case v := <-valCh: + if string(v) != "newer" { + t.Errorf("expected 'newer', got '%s'", string(v)) + } + case <-ctxT.Done(): + t.Fatal(ctxT.Err()) + } +} + +func TestGetValues(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dhtA := setupDHT(ctx, t, false) + dhtB := setupDHT(ctx, t, false) + + defer dhtA.Close() + defer dhtB.Close() + defer dhtA.host.Close() + defer dhtB.host.Close() + + connect(t, ctx, dhtA, dhtB) + + ctxT, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + err := dhtB.PutValue(ctxT, "/v/hello", []byte("newer")) + if err != nil { + t.Error(err) + } + + err = dhtA.PutValue(ctxT, "/v/hello", []byte("valid")) + if err != nil { + t.Error(err) + } + + ctxT, cancel = context.WithTimeout(ctx, time.Second*2) + defer cancel() + vals, err := dhtA.GetValues(ctxT, "/v/hello", 16) + if err != nil { + t.Fatal(err) + } + + if len(vals) != 2 { + t.Fatalf("expected to get 2 values, got %d", len(vals)) + } + + sort.Slice(vals, func(i, j int) bool { return string(vals[i].Val) < string(vals[j].Val) }) + + if string(vals[0].Val) != "valid" { + t.Errorf("unexpected vals[0]: %s", string(vals[0].Val)) + } + if string(vals[1].Val) != "valid" { + t.Errorf("unexpected vals[1]: %s", string(vals[1].Val)) + } +} + func TestValueGetInvalid(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -332,7 +436,7 @@ func TestValueGetInvalid(t *testing.T) { defer cancel() valb, err := dhtB.GetValue(ctxT, "/v/hello") if err != experr { - t.Errorf("Set/Get %v: Expected %v error but got %v", val, experr, err) + t.Errorf("Set/Get %v: Expected '%v' error but got '%v'", val, experr, err) } else if err == nil && string(valb) != exp { t.Errorf("Expected '%v' got '%s'", exp, string(valb)) } @@ -1271,12 +1375,12 @@ func TestGetSetPluggedProtocol(t *testing.T) { err = dhtA.PutValue(ctx, "/v/cat", []byte("meow")) if err == nil || !strings.Contains(err.Error(), "failed to find any peer in table") { - t.Fatal("should not have been able to find any peers in routing table") + t.Fatalf("put should not have been able to find any peers in routing table, err:'%v'", err) } _, err = dhtB.GetValue(ctx, "/v/cat") if err == nil || !strings.Contains(err.Error(), "failed to find any peer in table") { - t.Fatal("should not have been able to find any peers in routing table") + t.Fatalf("get should not have been able to find any peers in routing table, err:'%v'", err) } }) } diff --git a/ext_test.go b/ext_test.go index ca93d8b1679..570e92b134b 100644 --- a/ext_test.go +++ b/ext_test.go @@ -2,7 +2,6 @@ package dht import ( "context" - "io" "math/rand" "testing" "time" @@ -37,6 +36,7 @@ func TestGetFailures(t *testing.T) { // Reply with failures to every message hosts[1].SetStreamHandler(d.protocols[0], func(s inet.Stream) { + time.Sleep(400 * time.Millisecond) s.Close() }) @@ -48,7 +48,7 @@ func TestGetFailures(t *testing.T) { err = merr[0] } - if err != io.EOF { + if err != context.DeadlineExceeded { t.Fatal("Got different error than we expected", err) } } else { diff --git a/routing.go b/routing.go index a808bf0d7a1..4aca3775bff 100644 --- a/routing.go +++ b/routing.go @@ -120,101 +120,192 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...ropts.Opti eip.Done() }() + // apply defaultQuorum if relevant var cfg ropts.Options if err := cfg.Apply(opts...); err != nil { return nil, err } + opts = append(opts, Quorum(getQuorum(&cfg, defaultQuorum))) - responsesNeeded := 0 - if !cfg.Offline { - responsesNeeded = getQuorum(&cfg) - } - - vals, err := dht.GetValues(ctx, key, responsesNeeded) + responses, err := dht.SearchValue(ctx, key, opts...) if err != nil { return nil, err } + var best []byte - recs := make([][]byte, 0, len(vals)) - for _, v := range vals { - if v.Val != nil { - recs = append(recs, v.Val) - } + for r := range responses { + best = r } - if len(recs) == 0 { + + if ctx.Err() != nil { + return best, ctx.Err() + } + + if best == nil { return nil, routing.ErrNotFound } + log.Debugf("GetValue %v %v", key, best) + return best, nil +} - i, err := dht.Validator.Select(key, recs) - if err != nil { +func (dht *IpfsDHT) SearchValue(ctx context.Context, key string, opts ...ropts.Option) (<-chan []byte, error) { + var cfg ropts.Options + if err := cfg.Apply(opts...); err != nil { return nil, err } - best := recs[i] - log.Debugf("GetValue %v %v", key, best) - if best == nil { - log.Errorf("GetValue yielded correct record with nil value.") - return nil, routing.ErrNotFound + responsesNeeded := 0 + if !cfg.Offline { + responsesNeeded = getQuorum(&cfg, -1) } - fixupRec := record.MakePutRecord(key, best) - for _, v := range vals { - // if someone sent us a different 'less-valid' record, lets correct them - if !bytes.Equal(v.Val, best) { - go func(v RecvdVal) { - if v.From == dht.self { - err := dht.putLocal(key, fixupRec) - if err != nil { - log.Error("Error correcting local dht entry:", err) - } + valCh, err := dht.getValues(ctx, key, responsesNeeded) + if err != nil { + return nil, err + } + + out := make(chan []byte) + go func() { + defer close(out) + + maxVals := responsesNeeded + if maxVals < 0 { + maxVals = defaultQuorum * 4 // we want some upper bound on how + // much correctional entries we will send + } + + // vals is used collect entries we got so far and send corrections to peers + // when we exit this function + vals := make([]RecvdVal, 0, maxVals) + var best *RecvdVal + + defer func() { + if len(vals) <= 1 || best == nil { + return + } + fixupRec := record.MakePutRecord(key, best.Val) + for _, v := range vals { + // if someone sent us a different 'less-valid' record, lets correct them + if !bytes.Equal(v.Val, best.Val) { + go func(v RecvdVal) { + if v.From == dht.self { + err := dht.putLocal(key, fixupRec) + if err != nil { + log.Error("Error correcting local dht entry:", err) + } + return + } + ctx, cancel := context.WithTimeout(dht.Context(), time.Second*30) + defer cancel() + err := dht.putValueToPeer(ctx, v.From, fixupRec) + if err != nil { + log.Debug("Error correcting DHT entry: ", err) + } + }(v) + } + } + }() + + for { + select { + case v, ok := <-valCh: + if !ok { return } - ctx, cancel := context.WithTimeout(dht.Context(), time.Second*30) - defer cancel() - err := dht.putValueToPeer(ctx, v.From, fixupRec) - if err != nil { - log.Debug("Error correcting DHT entry: ", err) + + if len(vals) < maxVals { + vals = append(vals, v) + } + + if v.Val == nil { + continue + } + // Select best value + if best != nil { + sel, err := dht.Validator.Select(key, [][]byte{best.Val, v.Val}) + if err != nil { + log.Warning("Failed to select dht key: ", err) + continue + } + if sel == 1 && !bytes.Equal(v.Val, best.Val) { + best = &v + select { + case out <- v.Val: + case <-ctx.Done(): + return + } + } + } else { + // Output first valid value + if err := dht.Validator.Validate(key, v.Val); err == nil { + best = &v + select { + case out <- v.Val: + case <-ctx.Done(): + return + } + } } - }(v) + case <-ctx.Done(): + return + } } - } + }() - return best, nil + return out, nil } // GetValues gets nvals values corresponding to the given key. func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []RecvdVal, err error) { eip := log.EventBegin(ctx, "GetValues") - defer func() { - eip.Append(loggableKey(key)) - if err != nil { - eip.SetError(err) - } - eip.Done() - }() - vals := make([]RecvdVal, 0, nvals) - var valslock sync.Mutex + + eip.Append(loggableKey(key)) + defer eip.Done() + + valCh, err := dht.getValues(ctx, key, nvals) + if err != nil { + eip.SetError(err) + return nil, err + } + + out := make([]RecvdVal, 0, nvals) + for val := range valCh { + out = append(out, val) + } + + return out, ctx.Err() +} + +func (dht *IpfsDHT) getValues(ctx context.Context, key string, nvals int) (<-chan RecvdVal, error) { + vals := make(chan RecvdVal, 1) + + done := func(err error) (<-chan RecvdVal, error) { + defer close(vals) + return vals, err + } // If we have it local, don't bother doing an RPC! lrec, err := dht.getLocal(key) if err != nil { // something is wrong with the datastore. - return nil, err + return done(err) } if lrec != nil { // TODO: this is tricky, we don't always want to trust our own value // what if the authoritative source updated it? log.Debug("have it locally") - vals = append(vals, RecvdVal{ + vals <- RecvdVal{ Val: lrec.GetValue(), From: dht.self, - }) + } - if nvals <= 1 { - return vals, nil + if nvals == 0 || nvals == 1 { + return done(nil) } + + nvals-- } else if nvals == 0 { - return nil, routing.ErrNotFound + return done(routing.ErrNotFound) } // get closest peers in the routing table @@ -222,9 +313,12 @@ func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []R log.Debugf("peers in rt: %d %s", len(rtp), rtp) if len(rtp) == 0 { log.Warning("No peers from routing table!") - return nil, kb.ErrLookupFailure + return done(kb.ErrLookupFailure) } + var valslock sync.Mutex + var got int + // setup the Query parent := ctx query := dht.newQuery(key, func(ctx context.Context, p peer.ID) (*dhtQueryResult, error) { @@ -259,10 +353,16 @@ func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []R From: p, } valslock.Lock() - vals = append(vals, rv) + select { + case vals <- rv: + case <-ctx.Done(): + valslock.Unlock() + return nil, ctx.Err() + } + got++ // If we have collected enough records, we're done - if len(vals) >= nvals { + if nvals == got { res.success = true } valslock.Unlock() @@ -277,18 +377,23 @@ func (dht *IpfsDHT) GetValues(ctx context.Context, key string, nvals int) (_ []R return res, nil }) - reqCtx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - _, err = query.Run(reqCtx, rtp) + go func() { + reqCtx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() - // We do have some values but we either ran out of peers to query or - // searched for a whole minute. - // - // We'll just call this a success. - if len(vals) > 0 && (err == routing.ErrNotFound || reqCtx.Err() == context.DeadlineExceeded) { - err = nil - } - return vals, err + _, err = query.Run(reqCtx, rtp) + + // We do have some values but we either ran out of peers to query or + // searched for a whole minute. + // + // We'll just call this a success. + if got > 0 && (err == routing.ErrNotFound || reqCtx.Err() == context.DeadlineExceeded) { + err = nil + } + done(err) + }() + + return vals, nil } // Provider abstraction for indirect stores. diff --git a/routing_options.go b/routing_options.go index e2b7af1f974..46083ea562e 100644 --- a/routing_options.go +++ b/routing_options.go @@ -6,6 +6,8 @@ import ( type quorumOptionKey struct{} +const defaultQuorum = 16 + // Quorum is a DHT option that tells the DHT how many peers it needs to get // values from before returning the best one. // @@ -20,10 +22,10 @@ func Quorum(n int) ropts.Option { } } -func getQuorum(opts *ropts.Options) int { +func getQuorum(opts *ropts.Options, ndefault int) int { responsesNeeded, ok := opts.Other[quorumOptionKey{}].(int) if !ok { - responsesNeeded = 16 + responsesNeeded = ndefault } return responsesNeeded }