diff --git a/wormhole/recv.go b/wormhole/recv.go index e9c92366..d6fcc05f 100644 --- a/wormhole/recv.go +++ b/wormhole/recv.go @@ -10,7 +10,6 @@ import ( "fmt" "hash" "io" - "log" "github.com/psanford/wormhole-william/internal/crypto" "github.com/psanford/wormhole-william/rendezvous" @@ -347,7 +346,6 @@ func (f *IncomingMessage) readCrypt(p []byte) (int, error) { if len(f.buf) == 0 { rec, err := f.cryptor.readRecord() if err == io.EOF { - log.Printf("unexpected eof! reclen=%d totallen=%d", len(rec), f.readCount) f.readErr = io.ErrUnexpectedEOF return 0, f.readErr } else if err != nil { diff --git a/wormhole/send.go b/wormhole/send.go index 8611d099..900f7de5 100644 --- a/wormhole/send.go +++ b/wormhole/send.go @@ -373,6 +373,15 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re totalSize = offer.Directory.ZipSize } + var cancel func() + ctx, cancel = context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + conn.Close() + }() + for { n, err := r.Read(recordSlice) if n > 0 { diff --git a/wormhole/wormhole_test.go b/wormhole/wormhole_test.go index ede534e7..7a95d3eb 100644 --- a/wormhole/wormhole_test.go +++ b/wormhole/wormhole_test.go @@ -8,12 +8,12 @@ import ( "fmt" "io" "io/ioutil" - "log" "net" "path/filepath" "strings" "sync" "testing" + "time" "github.com/klauspost/compress/zip" "github.com/psanford/wormhole-william/rendezvous/rendezvousservertest" @@ -288,7 +288,6 @@ func TestWormholeFileTransportSendRecvViaRelayServer(t *testing.T) { if !result.OK { t.Fatalf("Expected ok result but got: %+v", result) } - } func TestWormholeBigFileTransportSendRecvViaRelayServer(t *testing.T) { @@ -389,14 +388,14 @@ func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) { _, err = io.ReadFull(receiver, initialBuffer) if err != nil { - log.Fatal(err) + t.Fatal(err) } cancel() _, err = ioutil.ReadAll(receiver) if err == nil { - log.Fatalf("Expected read error but got none") + t.Fatalf("Expected read error but got none") } result := <-resultCh @@ -405,6 +404,62 @@ func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) { } } +func TestWormholeFileTransportSendMidStreamCancel(t *testing.T) { + ctx := context.Background() + + rs := rendezvousservertest.NewServer() + defer rs.Close() + + url := rs.WebSocketURL() + + testDisableLocalListener = true + defer func() { testDisableLocalListener = false }() + + relayServer := newTestRelayServer() + defer relayServer.close() + + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayAddress = relayServer.addr + + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayAddress = relayServer.addr + + fileContent := make([]byte, 1<<16) + for i := 0; i < len(fileContent); i++ { + fileContent[i] = byte(i) + } + + sendCtx, cancel := context.WithCancel(ctx) + + splitR := splitReader{ + Reader: bytes.NewReader(fileContent), + cancelAt: 1 << 10, + cancel: cancel, + } + + code, resultCh, err := c0.SendFile(sendCtx, "file.txt", &splitR) + if err != nil { + t.Fatal(err) + } + + receiver, err := c1.Receive(ctx, code) + if err != nil { + t.Fatal(err) + } + + _, err = ioutil.ReadAll(receiver) + if err == nil { + t.Fatal("Expected read error but got none") + } + + result := <-resultCh + if result.OK { + t.Fatal("Expected send resultCh to error but got none") + } +} + func TestWormholeDirectoryTransportSendRecvDirect(t *testing.T) { ctx := context.Background() @@ -653,3 +708,24 @@ func (ts *testRelayServer) handleConn(c net.Conn) { existing.Close() } } + +type splitReader struct { + *bytes.Reader + offset int + cancelAt int + cancel func() + didCancel bool +} + +func (s *splitReader) Read(b []byte) (int, error) { + n, err := s.Reader.Read(b) + s.offset += n + if !s.didCancel && s.offset >= s.cancelAt { + s.cancel() + s.didCancel = true + // yield the cpu to give the cancellation goroutine a chance + // to run (esp important for when GOMAXPROCS=1) + time.Sleep(1 * time.Millisecond) + } + return n, err +}