diff --git a/datatransfer/impl/graphsync/graphsync.go b/datatransfer/impl/graphsync/graphsync.go index b00b8c79e86..4ede42fee05 100644 --- a/datatransfer/impl/graphsync/graphsync.go +++ b/datatransfer/impl/graphsync/graphsync.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - cidlink "github.com/ipld/go-ipld-prime/linking/cid" "math/rand" "reflect" @@ -14,6 +13,7 @@ import ( "github.com/ipld/go-ipld-prime" "github.com/ipld/go-ipld-prime/encoding/dagcbor" ipldfree "github.com/ipld/go-ipld-prime/impl/free" + cidlink "github.com/ipld/go-ipld-prime/linking/cid" "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" @@ -109,24 +109,28 @@ func (impl *graphsyncImpl) OpenPushDataChannel(ctx context.Context, to peer.ID, if err != nil { return datatransfer.ChannelID{}, err } - chid := impl.createNewChannel(tid, to, baseCid, selector, voucher) + chid := impl.createNewChannel(tid, to, baseCid, selector, voucher, "", to) return chid, nil } // OpenPullDataChannel opens a data transfer that will request data from the sending peer and // transfer parts of the piece that match the selector func (impl *graphsyncImpl) OpenPullDataChannel(ctx context.Context, to peer.ID, voucher datatransfer.Voucher, baseCid cid.Cid, selector ipld.Node) (datatransfer.ChannelID, error) { + tid, err := impl.sendRequest(ctx, selector, true, voucher, baseCid, to) if err != nil { return datatransfer.ChannelID{}, err } - chid := impl.createNewChannel(tid, to, baseCid, selector, voucher) + chid := impl.createNewChannel(tid, to, baseCid, selector, voucher, to, "") return chid, nil } // createNewChannel creates a new channel id -func (impl *graphsyncImpl) createNewChannel(tid datatransfer.TransferID, to peer.ID, baseCid cid.Cid, selector ipld.Node, voucher datatransfer.Voucher) datatransfer.ChannelID { - return datatransfer.ChannelID{To: to, ID: tid} +func (impl *graphsyncImpl) createNewChannel(tid datatransfer.TransferID, to peer.ID, baseCid cid.Cid, selector ipld.Node, voucher datatransfer.Voucher, sender, receiver peer.ID) datatransfer.ChannelID { + chid := datatransfer.ChannelID{To: to, ID: tid} + chst := datatransfer.ChannelState{Channel: datatransfer.NewChannel(0, baseCid, selector, voucher, sender, receiver, 0)} + impl.channels[chid] = chst + return chid } // sendRequest encapsulates message creation and posting to the data transfer network with the provided parameters @@ -283,7 +287,19 @@ func (receiver *graphsyncReceiver) ReceiveResponse( if !incoming.Accepted() { evt = datatransfer.Error } else { - evt = datatransfer.Progress // for now + chid := datatransfer.ChannelID{ + To: sender, + ID: incoming.TransferID(), + } + channel, ok := receiver.impl.channels[chid] + if ok { + baseCid := channel.BaseCID() + root := cidlink.Link{baseCid} + go func() { + receiver.impl.gs.Request(ctx, sender, root, channel.Selector()) + }() + } + evt = datatransfer.Progress } receiver.impl.notifySubscribers(evt, datatransfer.ChannelState{}) } diff --git a/datatransfer/impl/graphsync/graphsync_test.go b/datatransfer/impl/graphsync/graphsync_test.go index 146af386cf6..adc571854a4 100644 --- a/datatransfer/impl/graphsync/graphsync_test.go +++ b/datatransfer/impl/graphsync/graphsync_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "errors" + "fmt" "io" "io/ioutil" "math/rand" @@ -775,84 +776,97 @@ func TestDataTransferInitiatingPushGraphsyncRequests(t *testing.T) { // TODO: get passing to complete https://github.com/filecoin-project/go-data-transfer/issues/21 func TestDataTransferInitiatingPullGraphsyncRequests(t *testing.T) { - //ctx := context.Background() - //ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - //defer cancel() - //gsData := newGraphsyncTestingData(t, ctx) - //host1 := gsData.host1 - //host2 := gsData.host2 - // - //gs2 := &fakeGraphSync{ - // receivedRequests: make(chan receivedGraphSyncRequest, 1), - //} - //voucher := fakeDTType{"applesauce"} - //baseCid := testutil.GenerateCids(1)[0] - // - //gs1 := &fakeGraphSync{ - // receivedRequests: make(chan receivedGraphSyncRequest, 1), - //} - //dt1 := NewGraphSyncDataTransfer(ctx, host1, gs1) - // - //t.Run("with successful validation", func(t *testing.T) { - // sv := newSV() - // sv.expectSuccessPull() - // - // dt2 := NewGraphSyncDataTransfer(ctx, host2, gs2) - // err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) - // require.NoError(t, err) - // - // _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.allSelector) - // require.NoError(t, err) - // - // var requestReceived receivedGraphSyncRequest - // select { - // case <-ctx.Done(): - // t.Fatal("did not receive message sent") - // case requestReceived = <-gs1.receivedRequests: - // } - // - // sv.verifyExpectations(t) - // - // receiver := requestReceived.p - // require.Equal(t, receiver, host2.ID()) - // - // cl, ok := requestReceived.root.(cidlink.Link) - // require.True(t, ok) - // require.Equal(t, baseCid, cl.Cid) - // - // require.Equal(t, gsData.allSelector, requestReceived.selector) - //}) - // - //t.Run("with error validation", func(t *testing.T) { - // sv := newSV() - // sv.expectErrorPull() - // - // dt2 := NewGraphSyncDataTransfer(ctx, host2, gs2) - // err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) - // require.NoError(t, err) - // - // subscribeCalls := make(chan struct{}, 1) - // subscribe := func(event datatransfer.Event, channelState datatransfer.ChannelState) { - // if event == datatransfer.Error { - // subscribeCalls <- struct{}{} - // } - // } - // unsub := dt1.SubscribeToEvents(subscribe) - // _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.allSelector) - // require.NoError(t, err) - // - // select { - // case <-ctx.Done(): - // t.Fatal("subscribed events not received") - // case <-subscribeCalls: - // } - // - // sv.verifyExpectations(t) - // - // // no graphsync request should be scheduled - // require.Empty(t, gs1.receivedRequests) - // unsub() - //}) + ctx := context.Background() + gsData := newGraphsyncTestingData(t, ctx) + host1 := gsData.host1 + host2 := gsData.host2 + + voucher := fakeDTType{"applesauce"} + baseCid := testutil.GenerateCids(1)[0] + + t.Run("with successful validation", func(t *testing.T) { + gs1 := &fakeGraphSync{ + receivedRequests: make(chan receivedGraphSyncRequest, 1), + } + gs2 := &fakeGraphSync{ + receivedRequests: make(chan receivedGraphSyncRequest, 1), + } + + sv := newSV() + sv.expectSuccessPull() + + bg := ctx + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + dt1 := NewGraphSyncDataTransfer(bg, host1, gs1) + dt2 := NewGraphSyncDataTransfer(bg, host2, gs2) + err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + require.NoError(t, err) + + _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.allSelector) + require.NoError(t, err) + + var requestReceived receivedGraphSyncRequest + select { + case <-ctx.Done(): + t.Fatal("did not receive message sent") + case requestReceived = <-gs1.receivedRequests: + } + // give a little time for the validation to happen + time.Sleep(15*time.Millisecond) + sv.verifyExpectations(t) + + receiver := requestReceived.p + require.Equal(t, receiver, host2.ID()) + + cl, ok := requestReceived.root.(cidlink.Link) + require.True(t, ok) + require.Equal(t, baseCid, cl.Cid) + + require.Equal(t, gsData.allSelector, requestReceived.selector) + }) + + t.Run("with error validation", func(t *testing.T) { + gs1 := &fakeGraphSync{ + receivedRequests: make(chan receivedGraphSyncRequest, 1), + } + gs2 := &fakeGraphSync{ + receivedRequests: make(chan receivedGraphSyncRequest, 1), + } + + dt1 := NewGraphSyncDataTransfer(ctx, host1, gs1) + sv := newSV() + sv.expectErrorPull() + + dt2 := NewGraphSyncDataTransfer(ctx, host2, gs2) + err := dt2.RegisterVoucherType(reflect.TypeOf(&fakeDTType{}), sv) + require.NoError(t, err) + + subscribeCalls := make(chan struct{}, 1) + subscribe := func(event datatransfer.Event, channelState datatransfer.ChannelState) { + if event == datatransfer.Error { + subscribeCalls <- struct{}{} + } + } + unsub := dt1.SubscribeToEvents(subscribe) + _, err = dt1.OpenPullDataChannel(ctx, host2.ID(), &voucher, baseCid, gsData.allSelector) + require.NoError(t, err) + + select { + case <-ctx.Done(): + t.Fatal("subscribed events not received") + case <-subscribeCalls: + } + + // give a little time for the validation to happen + time.Sleep(15*time.Millisecond) + sv.verifyExpectations(t) + + // no graphsync request should be scheduled + require.Empty(t, gs1.receivedRequests) + unsub() + }) } type receivedGraphSyncMessage struct { @@ -1347,11 +1361,13 @@ type fakeGraphSync struct { // Request initiates a new GraphSync request to the given peer using the given selector spec. func (fgs *fakeGraphSync) Request(ctx context.Context, p peer.ID, root ipld.Link, selector ipld.Node, extensions ...graphsync.ExtensionData) (<-chan graphsync.ResponseProgress, <-chan error) { + fgs.receivedRequests <- receivedGraphSyncRequest{p, root, selector, extensions} responses := make(chan graphsync.ResponseProgress) errors := make(chan error) close(responses) close(errors) + fmt.Println("fakeGraphSync.Request") return responses, errors }