Skip to content

Commit

Permalink
Support cancellation for file/dir receives
Browse files Browse the repository at this point in the history
Canceling the context passed to Receive() will close the connection
and error on the next .Read().

Closes: #38 [via git-merge-pr]
  • Loading branch information
psanford committed Mar 14, 2021
1 parent 231ea97 commit 8d0bf85
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
12 changes: 12 additions & 0 deletions wormhole/recv.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ func (c *Client) Receive(ctx context.Context, code string) (fr *IncomingMessage,
fr.UncompressedBytes = int(offer.File.FileSize)
fr.UncompressedBytes64 = offer.File.FileSize
fr.FileCount = 1
fr.ctx = ctx
} else if offer.Directory != nil {
fr.Type = TransferDirectory
fr.Name = offer.Directory.Dirname
Expand All @@ -139,6 +140,7 @@ func (c *Client) Receive(ctx context.Context, code string) (fr *IncomingMessage,
fr.UncompressedBytes = int(offer.Directory.NumBytes)
fr.UncompressedBytes64 = offer.Directory.NumBytes
fr.FileCount = int(offer.Directory.NumFiles)
fr.ctx = ctx
} else {
return nil, errors.New("got non-file transfer offer")
}
Expand Down Expand Up @@ -273,6 +275,8 @@ type IncomingMessage struct {
sha256 hash.Hash

readErr error

ctx context.Context
}

// Read the decrypted contents sent to this client.
Expand Down Expand Up @@ -324,6 +328,14 @@ func (f *IncomingMessage) readCrypt(p []byte) (int, error) {
return 0, f.readErr
}

if err := f.ctx.Err(); err != nil {
f.readErr = err
if f.cryptor != nil {
f.cryptor.Close()
}
return 0, err
}

if !f.transferInitialized {
f.transferInitialized = true
err := f.initializeTransfer()
Expand Down
63 changes: 63 additions & 0 deletions wormhole/wormhole_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"path/filepath"
"strings"
Expand Down Expand Up @@ -342,6 +343,68 @@ func TestWormholeBigFileTransportSendRecvViaRelayServer(t *testing.T) {

}

func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) {
ctx := context.Background()

rs := rendezvousservertest.NewServer()
defer rs.Close()

url := rs.WebSocketURL()

testDisableLocalListener = true
defer func() { testDisableLocalListener = false }()

relayServer := newTestRelayServer()
defer relayServer.close()

var c0 Client
c0.RendezvousURL = url
c0.TransitRelayAddress = relayServer.addr

var c1 Client
c1.RendezvousURL = url
c1.TransitRelayAddress = relayServer.addr

fileContent := make([]byte, 1<<16)
for i := 0; i < len(fileContent); i++ {
fileContent[i] = byte(i)
}

buf := bytes.NewReader(fileContent)

code, resultCh, err := c0.SendFile(ctx, "file.txt", buf)
if err != nil {
t.Fatal(err)
}

childCtx, cancel := context.WithCancel(ctx)
defer cancel()

receiver, err := c1.Receive(childCtx, code)
if err != nil {
t.Fatal(err)
}

initialBuffer := make([]byte, 1<<10)

_, err = io.ReadFull(receiver, initialBuffer)
if err != nil {
log.Fatal(err)
}

cancel()

_, err = ioutil.ReadAll(receiver)
if err == nil {
log.Fatalf("Expected read error but got none")
}

result := <-resultCh
if result.OK {
t.Fatalf("Expected error result but got ok")
}
}

func TestWormholeDirectoryTransportSendRecvDirect(t *testing.T) {
ctx := context.Background()

Expand Down

0 comments on commit 8d0bf85

Please sign in to comment.