Skip to content

Commit

Permalink
sendFileDirectory: cancel transfers via context
Browse files Browse the repository at this point in the history
Canceling the context will now cancel an in-flight file transfer.
This should cover most cases. There might still be some edge cases
where cancellation doesn't work.

Closes: #40 [via git-merge-pr]
  • Loading branch information
psanford committed Mar 14, 2021
1 parent 8d0bf85 commit 90fd6cd
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 6 deletions.
2 changes: 0 additions & 2 deletions wormhole/recv.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"fmt"
"hash"
"io"
"log"

"github.com/psanford/wormhole-william/internal/crypto"
"github.com/psanford/wormhole-william/rendezvous"
Expand Down Expand Up @@ -347,7 +346,6 @@ func (f *IncomingMessage) readCrypt(p []byte) (int, error) {
if len(f.buf) == 0 {
rec, err := f.cryptor.readRecord()
if err == io.EOF {
log.Printf("unexpected eof! reclen=%d totallen=%d", len(rec), f.readCount)
f.readErr = io.ErrUnexpectedEOF
return 0, f.readErr
} else if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions wormhole/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,15 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re
totalSize = offer.Directory.ZipSize
}

var cancel func()
ctx, cancel = context.WithCancel(ctx)
defer cancel()

go func() {
<-ctx.Done()
conn.Close()
}()

for {
n, err := r.Read(recordSlice)
if n > 0 {
Expand Down
84 changes: 80 additions & 4 deletions wormhole/wormhole_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"path/filepath"
"strings"
"sync"
"testing"
"time"

"github.com/klauspost/compress/zip"
"github.com/psanford/wormhole-william/rendezvous/rendezvousservertest"
Expand Down Expand Up @@ -288,7 +288,6 @@ func TestWormholeFileTransportSendRecvViaRelayServer(t *testing.T) {
if !result.OK {
t.Fatalf("Expected ok result but got: %+v", result)
}

}

func TestWormholeBigFileTransportSendRecvViaRelayServer(t *testing.T) {
Expand Down Expand Up @@ -389,14 +388,14 @@ func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) {

_, err = io.ReadFull(receiver, initialBuffer)
if err != nil {
log.Fatal(err)
t.Fatal(err)
}

cancel()

_, err = ioutil.ReadAll(receiver)
if err == nil {
log.Fatalf("Expected read error but got none")
t.Fatalf("Expected read error but got none")
}

result := <-resultCh
Expand All @@ -405,6 +404,62 @@ func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) {
}
}

func TestWormholeFileTransportSendMidStreamCancel(t *testing.T) {
ctx := context.Background()

rs := rendezvousservertest.NewServer()
defer rs.Close()

url := rs.WebSocketURL()

testDisableLocalListener = true
defer func() { testDisableLocalListener = false }()

relayServer := newTestRelayServer()
defer relayServer.close()

var c0 Client
c0.RendezvousURL = url
c0.TransitRelayAddress = relayServer.addr

var c1 Client
c1.RendezvousURL = url
c1.TransitRelayAddress = relayServer.addr

fileContent := make([]byte, 1<<16)
for i := 0; i < len(fileContent); i++ {
fileContent[i] = byte(i)
}

sendCtx, cancel := context.WithCancel(ctx)

splitR := splitReader{
Reader: bytes.NewReader(fileContent),
cancelAt: 1 << 10,
cancel: cancel,
}

code, resultCh, err := c0.SendFile(sendCtx, "file.txt", &splitR)
if err != nil {
t.Fatal(err)
}

receiver, err := c1.Receive(ctx, code)
if err != nil {
t.Fatal(err)
}

_, err = ioutil.ReadAll(receiver)
if err == nil {
t.Fatal("Expected read error but got none")
}

result := <-resultCh
if result.OK {
t.Fatal("Expected send resultCh to error but got none")
}
}

func TestWormholeDirectoryTransportSendRecvDirect(t *testing.T) {
ctx := context.Background()

Expand Down Expand Up @@ -653,3 +708,24 @@ func (ts *testRelayServer) handleConn(c net.Conn) {
existing.Close()
}
}

type splitReader struct {
*bytes.Reader
offset int
cancelAt int
cancel func()
didCancel bool
}

func (s *splitReader) Read(b []byte) (int, error) {
n, err := s.Reader.Read(b)
s.offset += n
if !s.didCancel && s.offset >= s.cancelAt {
s.cancel()
s.didCancel = true
// yield the cpu to give the cancellation goroutine a chance
// to run (esp important for when GOMAXPROCS=1)
time.Sleep(1 * time.Millisecond)
}
return n, err
}

0 comments on commit 90fd6cd

Please sign in to comment.