diff --git a/wormhole/file_transport.go b/wormhole/file_transport.go index 2ee7cc95..3790c94e 100644 --- a/wormhole/file_transport.go +++ b/wormhole/file_transport.go @@ -14,10 +14,12 @@ import ( "net" "strconv" "time" + "net/url" "github.com/psanford/wormhole-william/internal/crypto" "golang.org/x/crypto/hkdf" "golang.org/x/crypto/nacl/secretbox" + "nhooyr.io/websocket" ) type fileTransportAck struct { @@ -33,6 +35,10 @@ const ( TransferText ) +// UnsupportedProtocolErr is used in the default case of protocol switch +// statements to account for unexpected protocols. +var UnsupportedProtocolErr = errors.New("unsupported protocol") + func (tt TransferType) String() string { switch tt { case TransferFile: @@ -154,18 +160,18 @@ func (d *transportCryptor) writeRecord(msg []byte) error { return err } -func newFileTransport(transitKey []byte, appID, relayAddr string) *fileTransport { +func newFileTransport(transitKey []byte, appID string, relayURL *url.URL) *fileTransport { return &fileTransport{ transitKey: transitKey, appID: appID, - relayAddr: relayAddr, + relayURL: relayURL, } } type fileTransport struct { listener net.Listener relayConn net.Conn - relayAddr string + relayURL *url.URL transitKey []byte appID string } @@ -177,19 +183,21 @@ func (t *fileTransport) connectViaRelay(otherTransit *transitMsg) (net.Conn, err failChan := make(chan string) var count int - - for _, outerHint := range otherTransit.HintsV1 { - if outerHint.Type == "relay-v1" { - for _, innerHint := range outerHint.Hints { - if innerHint.Type == "direct-tcp-v1" { - count++ - ctx, cancel := context.WithCancel(context.Background()) - addr := net.JoinHostPort(innerHint.Hostname, strconv.Itoa(innerHint.Port)) - - cancelFuncs[addr] = cancel - - go t.connectToRelay(ctx, addr, successChan, failChan) + for _, relay := range otherTransit.HintsV1 { + if relay.Type == "relay-v1" { + for _, endpoint := range relay.Hints { + var addr string + switch endpoint.Type { + case "direct-tcp-v1": + addr = net.JoinHostPort(endpoint.Hostname, strconv.Itoa(endpoint.Port)) + case "websocket": + addr = endpoint.Url } + ctx, cancel := context.WithCancel(context.Background()) + cancelFuncs[addr] = cancel + + count++ + go t.connectToRelay(ctx, successChan, failChan) } } } @@ -250,12 +258,30 @@ func (t *fileTransport) connectDirect(otherTransit *transitMsg) (net.Conn, error return conn, nil } -func (t *fileTransport) connectToRelay(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) { +func (t *fileTransport) connectToRelay(ctx context.Context, successChan chan net.Conn, failChan chan string) { var d net.Dialer - conn, err := d.DialContext(ctx, "tcp", addr) - if err != nil { - failChan <- addr - return + var conn net.Conn + var err error + addr := fmt.Sprintf("%s:%s", t.relayURL.Hostname(), t.relayURL.Port()) + + switch t.relayURL.Scheme { + case "tcp": + conn, err = d.DialContext(ctx, "tcp", addr) + + if err != nil { + failChan <- addr + return + } + case "ws", "wss": + var wsconn *websocket.Conn + wsconn, _, err = websocket.Dial(ctx, t.relayURL.String(), nil) + + if err != nil { + failChan <- addr + return + } + + conn = websocket.NetConn(ctx, wsconn, websocket.MessageBinary) } _, err = conn.Write(t.relayHandshakeHeader()) @@ -277,7 +303,7 @@ func (t *fileTransport) connectToRelay(ctx context.Context, addr string, success return } - t.directRecvHandshake(ctx, addr, conn, successChan, failChan) + t.directRecvHandshake(ctx, conn, successChan, failChan) } func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) { @@ -289,12 +315,13 @@ func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, su return } - t.directRecvHandshake(ctx, addr, conn, successChan, failChan) + t.directRecvHandshake(ctx, conn, successChan, failChan) } -func (t *fileTransport) directRecvHandshake(ctx context.Context, addr string, conn net.Conn, successChan chan net.Conn, failChan chan string) { +func (t *fileTransport) directRecvHandshake(ctx context.Context, conn net.Conn, successChan chan net.Conn, failChan chan string) { expectHeader := t.senderHandshakeHeader() + addr := t.relayURL.Hostname() gotHeader := make([]byte, len(expectHeader)) _, err := io.ReadFull(conn, gotHeader) @@ -372,27 +399,43 @@ func (t *fileTransport) makeTransitMsg() (*transitMsg, error) { } if t.relayConn != nil { - relayHost, portStr, err := net.SplitHostPort(t.relayAddr) - if err != nil { - return nil, err + var relayType string + switch t.relayURL.Scheme { + case "tcp": + relayType = "direct-tcp-v1" + case "ws": + relayType = "websocket" + case "wss": + relayType = "websocket" + default: + return nil, fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme) } - - relayPort, err := strconv.Atoi(portStr) - if err != nil { - return nil, fmt.Errorf("port isn't an integer? %s", portStr) - } - - msg.HintsV1 = append(msg.HintsV1, transitHintsV1{ - Type: "relay-v1", - Hints: []transitHintsV1Hint{ - { - Type: "direct-tcp-v1", - Priority: 2.0, - Hostname: relayHost, - Port: relayPort, + if relayType == "direct-tcp-v1" { + var port, err = strconv.Atoi(t.relayURL.Port()) + if err != nil { + return nil, fmt.Errorf("invalid port") + } + msg.HintsV1 = append(msg.HintsV1, transitHintsV1{ + Type: "relay-v1", + Hints: []transitHintsRelay{ + { + Type: relayType, + Hostname: t.relayURL.Hostname(), + Port: port, + }, }, - }, - }) + }) + } else { + msg.HintsV1 = append(msg.HintsV1, transitHintsV1{ + Type: "relay-v1", + Hints: []transitHintsRelay{ + { + Type: relayType, + Url: t.relayURL.String(), + }, + }, + }) + } } return &msg, nil @@ -449,23 +492,44 @@ func (t *fileTransport) listen() error { if testDisableLocalListener { return nil } + switch t.relayURL.Scheme { + case "tcp": + l, err := net.Listen("tcp", ":0") + if err != nil { + return err + } - l, err := net.Listen("tcp", ":0") - if err != nil { - return err + t.listener = l + case "ws", "wss": + t.listener = nil + default: + return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme) } - t.listener = l return nil } -func (t *fileTransport) listenRelay() error { - if t.relayAddr == "" { - return nil - } - conn, err := net.Dial("tcp", t.relayAddr) - if err != nil { - return err +func (t *fileTransport) listenRelay(ctx context.Context) (err error) { + var conn net.Conn + switch t.relayURL.Scheme { + case "tcp": + // NB: don't dial the relay if we don't have an address. + addr := t.relayURL.Host + if addr == ":0" { + return nil + } + conn, err = net.Dial("tcp", addr) + if err != nil { + return err + } + case "ws", "wss": + c, _, err := websocket.Dial(ctx, t.relayURL.String(), nil) + if err != nil { + return fmt.Errorf("websocket.Dial failed") + } + conn = websocket.NetConn(ctx, c, websocket.MessageBinary) + default: + return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme) } _, err = conn.Write(t.relayHandshakeHeader()) diff --git a/wormhole/recv.go b/wormhole/recv.go index 230e8484..888c5f4f 100644 --- a/wormhole/recv.go +++ b/wormhole/recv.go @@ -157,7 +157,11 @@ func (c *Client) Receive(ctx context.Context, code string) (fr *IncomingMessage, } transitKey := deriveTransitKey(clientProto.sharedKey, appID) - transport := newFileTransport(transitKey, appID, c.relayAddr()) + relayUrl, err := c.relayURL() + if err != nil { + return nil, fmt.Errorf("Invalid relay URL") + } + transport := newFileTransport(transitKey, appID, relayUrl) transitMsg, err := transport.makeTransitMsg() if err != nil { diff --git a/wormhole/send.go b/wormhole/send.go index b4d5a3d5..327bdc46 100644 --- a/wormhole/send.go +++ b/wormhole/send.go @@ -185,10 +185,6 @@ func (c *Client) SendText(ctx context.Context, msg string, opts ...SendOption) ( } func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Reader, opts ...SendOption) (string, chan SendResult, error) { - if err := c.validateRelayAddr(); err != nil { - return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err) - } - var options sendOptions for _, opt := range opts { err := opt.setOption(&options) @@ -296,15 +292,20 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re } } + var relayUrl, err = c.relayURL() + if err != nil { + sendErr(fmt.Errorf("Invalid relay URL")) + return + } transitKey := deriveTransitKey(clientProto.sharedKey, appID) - transport := newFileTransport(transitKey, appID, c.relayAddr()) + transport := newFileTransport(transitKey, appID, relayUrl) err = transport.listen() if err != nil { sendErr(err) return } - err = transport.listenRelay() + err = transport.listenRelay(ctx) if err != nil { sendErr(err) return @@ -443,10 +444,6 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re // receiver, a result channel that will be written to after the receiver attempts to read (either successfully or not) // and an error if one occurred. func (c *Client) SendFile(ctx context.Context, fileName string, r io.ReadSeeker, opts ...SendOption) (string, chan SendResult, error) { - if err := c.validateRelayAddr(); err != nil { - return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err) - } - size, err := readSeekerSize(r) if err != nil { return "", nil, err diff --git a/wormhole/wormhole.go b/wormhole/wormhole.go index e947bf4b..c23c578a 100644 --- a/wormhole/wormhole.go +++ b/wormhole/wormhole.go @@ -9,11 +9,11 @@ import ( "errors" "fmt" "io" - "net" "reflect" "strconv" "strings" "sync" + "net/url" "github.com/psanford/wormhole-william/internal/crypto" "github.com/psanford/wormhole-william/rendezvous" @@ -33,10 +33,10 @@ type Client struct { // DefaultRendezvousURL will be used. RendezvousURL string - // TransitRelayAddress is the host:port address to offer + // TransitRelayURL is the proto://host:port address to offer // to use for file transfers where direct connections are unavailable. - // If empty, DefaultTransitRelayAddress will be used. - TransitRelayAddress string + // If empty, DefaultTransitRelayURL will be used. + TransitRelayURL string // PassPhraseComponentLength is the number of words to use // when generating a passprase. Any value less than 2 will @@ -63,8 +63,8 @@ var ( // DefaultRendezvousURL is the default Rendezvous server to use. DefaultRendezvousURL = "ws://relay.magic-wormhole.io:4000/v1" - // DefaultTransitRelayAddress is the default transit server to ues. - DefaultTransitRelayAddress = "transit.magic-wormhole.io:4001" + // DefaultTransitRelayURL is the default transit server to ues. + DefaultTransitRelayURL = "tcp://transit.magic-wormhole.io:4001" ) func (c *Client) url() string { @@ -89,19 +89,12 @@ func (c *Client) wordCount() int { } } -func (c *Client) relayAddr() string { - if c.TransitRelayAddress != "" { - return c.TransitRelayAddress +func (c *Client) relayURL() (*url.URL, error) { + var rurl = c.TransitRelayURL + if rurl == "" { + rurl = DefaultTransitRelayURL } - return DefaultTransitRelayAddress -} - -func (c *Client) validateRelayAddr() error { - if c.relayAddr() == "" { - return nil - } - _, _, err := net.SplitHostPort(c.relayAddr()) - return err + return url.Parse(rurl) } // SendResult has information about whether or not a Send command was successful. @@ -268,18 +261,24 @@ type transitAbility struct { } type transitHintsV1 struct { + Type string `json:"type"` + // When type is "direct-tcp-v1" or "tor-tcp-v1" Hostname string `json:"hostname"` Port int `json:"port"` Priority float64 `json:"priority"` - Type string `json:"type"` - Hints []transitHintsV1Hint `json:"hints"` + // When type is "relay-v1" + Name string `json:"name,omitempty"` + Hints []transitHintsRelay `json:"hints"` } -type transitHintsV1Hint struct { - Hostname string `json:"hostname"` - Port int `json:"port"` - Priority float64 `json:"priority"` +type transitHintsRelay struct { Type string `json:"type"` + // When type is "direct-tcp-v1" + Hostname string `json:"hostname,omitempty"` + Port int `json:"port,omitempty"` + Priority float64 `json:"priority,omitempty"` + // When type is "websocket" + Url string `json:"url,omitempty"` } type transitMsg struct { diff --git a/wormhole/wormhole_test.go b/wormhole/wormhole_test.go index 79a37281..223ff2bb 100644 --- a/wormhole/wormhole_test.go +++ b/wormhole/wormhole_test.go @@ -9,9 +9,11 @@ import ( "io" "io/ioutil" "net" - "os" + "net/http" + "net/http/httptest" + "net/url" "path/filepath" - "runtime/pprof" + _ "runtime/pprof" "strings" "sync" "testing" @@ -21,8 +23,14 @@ import ( "github.com/psanford/wormhole-william/internal/crypto" "github.com/psanford/wormhole-william/rendezvous" "github.com/psanford/wormhole-william/rendezvous/rendezvousservertest" + "nhooyr.io/websocket" ) +var relayServerConstructors = map[string]func() *testRelayServer{ + "TCP": newTestTCPRelayServer, + "WS": newTestWSRelayServer, +} + func TestWormholeSendRecvText(t *testing.T) { ctx := context.Background() @@ -32,7 +40,7 @@ func TestWormholeSendRecvText(t *testing.T) { url := rs.WebSocketURL() // disable transit relay - DefaultTransitRelayAddress = "" + DefaultTransitRelayURL = "" var c0Verifier string var c0 Client @@ -163,7 +171,7 @@ func TestVerifierAbort(t *testing.T) { url := rs.WebSocketURL() // disable transit relay - DefaultTransitRelayAddress = "" + DefaultTransitRelayURL = "" var c0 Client c0.RendezvousURL = url @@ -206,7 +214,7 @@ func TestWormholeFileReject(t *testing.T) { url := rs.WebSocketURL() // disable transit relay for this test - DefaultTransitRelayAddress = "" + DefaultTransitRelayURL = "" var c0 Client c0.RendezvousURL = url @@ -251,46 +259,51 @@ func TestWormholeFileTransportSendRecvViaRelayServer(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() - defer relayServer.close() + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + relayURL := relayServer.url.String() + defer relayServer.close() - var c0 Client - c0.RendezvousURL = url - c0.TransitRelayAddress = relayServer.addr + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayURL = relayURL - var c1 Client - c1.RendezvousURL = url - c1.TransitRelayAddress = relayServer.addr + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayURL = relayURL - fileContent := make([]byte, 1<<16) - for i := 0; i < len(fileContent); i++ { - fileContent[i] = byte(i) - } + fileContent := make([]byte, 1<<16) + for i := 0; i < len(fileContent); i++ { + fileContent[i] = byte(i) + } - buf := bytes.NewReader(fileContent) + buf := bytes.NewReader(fileContent) - code, resultCh, err := c0.SendFile(ctx, "file.txt", buf) - if err != nil { - t.Fatal(err) - } + code, resultCh, err := c0.SendFile(ctx, "file.txt", buf) + if err != nil { + t.Fatal(err) + } - receiver, err := c1.Receive(ctx, code) - if err != nil { - t.Fatal(err) - } + receiver, err := c1.Receive(ctx, code) + if err != nil { + t.Fatal(err) + } - got, err := ioutil.ReadAll(receiver) - if err != nil { - t.Fatal(err) - } + got, err := ioutil.ReadAll(receiver) + if err != nil { + t.Fatal(err) + } - if !bytes.Equal(got, fileContent) { - t.Fatalf("File contents mismatch") - } + if !bytes.Equal(got, fileContent) { + t.Fatalf("File contents mismatch") + } - result := <-resultCh - if !result.OK { - t.Fatalf("Expected ok result but got: %+v", result) + result := <-resultCh + if !result.OK { + t.Fatalf("Expected ok result but got: %+v", result) + } + }) } } @@ -305,45 +318,49 @@ func TestWormholeBigFileTransportSendRecvViaRelayServer(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() - defer relayServer.close() - - var c0 Client - c0.RendezvousURL = url - c0.TransitRelayAddress = relayServer.addr - - var c1 Client - c1.RendezvousURL = url - c1.TransitRelayAddress = relayServer.addr - - // Create a fake file offer - var fakeBigSize int64 = 32098461509 - offer := &offerMsg{ - File: &offerFile{ - FileName: "fakefile", - FileSize: fakeBigSize, - }, - } + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + relayURL := relayServer.url.String() + defer relayServer.close() + + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayURL = relayURL + + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayURL = relayURL + + // Create a fake file offer + var fakeBigSize int64 = 32098461509 + offer := &offerMsg{ + File: &offerFile{ + FileName: "fakefile", + FileSize: fakeBigSize, + }, + } - // just a pretend reader - r := bytes.NewReader(make([]byte, 1)) + // just a pretend reader + r := bytes.NewReader(make([]byte, 1)) - // skip th wrapper so we can provide our own offer - code, _, err := c0.sendFileDirectory(ctx, offer, r) - //c0.SendFile(ctx, "file.txt", buf) - if err != nil { - t.Fatal(err) - } + // skip th wrapper so we can provide our own offer + code, _, err := c0.sendFileDirectory(ctx, offer, r) + //c0.SendFile(ctx, "file.txt", buf) + if err != nil { + t.Fatal(err) + } - receiver, err := c1.Receive(ctx, code) - if err != nil { - t.Fatal(err) - } + receiver, err := c1.Receive(ctx, code) + if err != nil { + t.Fatal(err) + } - if int64(receiver.TransferBytes64) != fakeBigSize { - t.Fatalf("Mismatch in size between what we are trying to send and what is (our parsed) offer. Expected %v but got %v", fakeBigSize, receiver.TransferBytes64) + if int64(receiver.TransferBytes64) != fakeBigSize { + t.Fatalf("Mismatch in size between what we are trying to send and what is (our parsed) offer. Expected %v but got %v", fakeBigSize, receiver.TransferBytes64) + } + }) } - } func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) { @@ -357,54 +374,59 @@ func TestWormholeFileTransportRecvMidStreamCancel(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() - defer relayServer.close() + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + relayURL := relayServer.url.String() + defer relayServer.close() - var c0 Client - c0.RendezvousURL = url - c0.TransitRelayAddress = relayServer.addr + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayURL = relayURL - var c1 Client - c1.RendezvousURL = url - c1.TransitRelayAddress = relayServer.addr + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayURL = relayURL - fileContent := make([]byte, 1<<16) - for i := 0; i < len(fileContent); i++ { - fileContent[i] = byte(i) - } + fileContent := make([]byte, 1<<16) + for i := 0; i < len(fileContent); i++ { + fileContent[i] = byte(i) + } - buf := bytes.NewReader(fileContent) + buf := bytes.NewReader(fileContent) - code, resultCh, err := c0.SendFile(ctx, "file.txt", buf) - if err != nil { - t.Fatal(err) - } + code, resultCh, err := c0.SendFile(ctx, "file.txt", buf) + if err != nil { + t.Fatal(err) + } - childCtx, cancel := context.WithCancel(ctx) - defer cancel() + childCtx, cancel := context.WithCancel(ctx) + defer cancel() - receiver, err := c1.Receive(childCtx, code) - if err != nil { - t.Fatal(err) - } + receiver, err := c1.Receive(childCtx, code) + if err != nil { + t.Fatal(err) + } - initialBuffer := make([]byte, 1<<10) + initialBuffer := make([]byte, 1<<10) - _, err = io.ReadFull(receiver, initialBuffer) - if err != nil { - t.Fatal(err) - } + _, err = io.ReadFull(receiver, initialBuffer) + if err != nil { + t.Fatal(err) + } - cancel() + cancel() - _, err = ioutil.ReadAll(receiver) - if err == nil { - t.Fatalf("Expected read error but got none") - } + _, err = ioutil.ReadAll(receiver) + if err == nil { + t.Fatalf("Expected read error but got none") + } - result := <-resultCh - if result.OK { - t.Fatalf("Expected error result but got ok") + result := <-resultCh + if result.OK { + t.Fatalf("Expected error result but got ok") + } + }) } } @@ -419,48 +441,53 @@ func TestWormholeFileTransportSendMidStreamCancel(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() - defer relayServer.close() + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + relayURL := relayServer.url.String() + defer relayServer.close() - var c0 Client - c0.RendezvousURL = url - c0.TransitRelayAddress = relayServer.addr + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayURL = relayURL - var c1 Client - c1.RendezvousURL = url - c1.TransitRelayAddress = relayServer.addr + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayURL = relayURL - fileContent := make([]byte, 1<<16) - for i := 0; i < len(fileContent); i++ { - fileContent[i] = byte(i) - } + fileContent := make([]byte, 1<<16) + for i := 0; i < len(fileContent); i++ { + fileContent[i] = byte(i) + } - sendCtx, cancel := context.WithCancel(ctx) + sendCtx, cancel := context.WithCancel(ctx) - splitR := splitReader{ - Reader: bytes.NewReader(fileContent), - cancelAt: 1 << 10, - cancel: cancel, - } + splitR := splitReader{ + Reader: bytes.NewReader(fileContent), + cancelAt: 1 << 10, + cancel: cancel, + } - code, resultCh, err := c0.SendFile(sendCtx, "file.txt", &splitR) - if err != nil { - t.Fatal(err) - } + code, resultCh, err := c0.SendFile(sendCtx, "file.txt", &splitR) + if err != nil { + t.Fatal(err) + } - receiver, err := c1.Receive(ctx, code) - if err != nil { - t.Fatal(err) - } + receiver, err := c1.Receive(ctx, code) + if err != nil { + t.Fatal(err) + } - gotMsg, err := ioutil.ReadAll(receiver) - if err == nil { - t.Fatalf("Expected read error but got none. got msg size: %d, orig_size: %d, cancel_at: %de", len(gotMsg), len(fileContent), splitR.cancelAt) - } + _, err = ioutil.ReadAll(receiver) + if err == nil { + t.Fatal("Expected read error but got none") + } - result := <-resultCh - if result.OK { - t.Fatal("Expected send resultCh to error but got none") + result := <-resultCh + if result.OK { + t.Fatal("Expected send resultCh to error but got none") + } + }) } } @@ -475,70 +502,72 @@ func TestPendingSendCancelable(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() - defer relayServer.close() - - c0 := Client{ - RendezvousURL: url, - TransitRelayAddress: relayServer.addr, - } + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + defer relayServer.close() - fileContent := make([]byte, 1<<16) - for i := 0; i < len(fileContent); i++ { - fileContent[i] = byte(i) - } + c0 := Client{ + RendezvousURL: url, + TransitRelayURL: relayServer.url.String(), + } - buf := bytes.NewReader(fileContent) + fileContent := make([]byte, 1<<16) + for i := 0; i < len(fileContent); i++ { + fileContent[i] = byte(i) + } - childCtx, cancel := context.WithCancel(ctx) - defer cancel() + buf := bytes.NewReader(fileContent) - code, resultCh, err := c0.SendFile(childCtx, "file.txt", buf) - if err != nil { - t.Fatal(err) - } + childCtx, cancel := context.WithCancel(ctx) + defer cancel() - // connect to mailbox to wait for c0 to write its initial message - rc := rendezvous.NewClient(url, crypto.RandSideID(), c0.appID()) + code, resultCh, err := c0.SendFile(childCtx, "file.txt", buf) + if err != nil { + t.Fatal(err) + } - _, err = rc.Connect(ctx) - if err != nil { - t.Fatal(err) - } + // connect to mailbox to wait for c0 to write its initial message + rc := rendezvous.NewClient(url, crypto.RandSideID(), c0.appID()) - defer rc.Close(ctx, rendezvous.Happy) - nameplate, err := nameplateFromCode(code) - if err != nil { - t.Fatal(err) - } + _, err = rc.Connect(ctx) + if err != nil { + t.Fatal(err) + } - err = rc.AttachMailbox(ctx, nameplate) - if err != nil { - t.Fatal(err) - } + defer rc.Close(ctx, rendezvous.Happy) + nameplate, err := nameplateFromCode(code) + if err != nil { + t.Fatal(err) + } - msgs := rc.MsgChan(ctx) + err = rc.AttachMailbox(ctx, nameplate) + if err != nil { + t.Fatal(err) + } - select { - case <-msgs: - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for c0 to send a message") - } + msgs := rc.MsgChan(ctx) - cancel() + select { + case <-msgs: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for c0 to send a message") + } - select { - case result := <-resultCh: - if result.OK { - t.Fatalf("Expected cancellation error but got OK") - } - if result.Error == nil { - t.Fatalf("Expected cancellation error") - } - case <-time.After(5 * time.Second): - // log all goroutines - pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) - t.Fatalf("Wait for result timed out") + cancel() + + select { + case result := <-resultCh: + if result.OK { + t.Fatalf("Expected cancellation error but got OK") + } + if result.Error == nil { + t.Fatalf("Expected cancellation error") + } + case <-time.After(5 * time.Second): + t.Fatalf("Wait for result timed out") + } + }) } } @@ -553,76 +582,78 @@ func TestPendingRecvCancelable(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() - defer relayServer.close() + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + defer relayServer.close() - c0 := Client{ - RendezvousURL: url, - TransitRelayAddress: relayServer.addr, - } + c0 := Client{ + RendezvousURL: url, + TransitRelayURL: relayServer.url.String(), + } - childCtx, cancel := context.WithCancel(ctx) - defer cancel() + childCtx, cancel := context.WithCancel(ctx) + defer cancel() - code := "87-firetrap-fallacy" - resultCh := make(chan error, 1) - go func() { - _, err := c0.Receive(childCtx, code) - resultCh <- err - }() + code := "87-firetrap-fallacy" + resultCh := make(chan error, 1) + go func() { + _, err := c0.Receive(childCtx, code) + resultCh <- err + }() - // wait to see mailbox has been allocated, and then - // wait to see PAKE message from receiver - rc := rendezvous.NewClient(url, crypto.RandSideID(), c0.appID()) + // wait to see mailbox has been allocated, and then + // wait to see PAKE message from receiver + rc := rendezvous.NewClient(url, crypto.RandSideID(), c0.appID()) - _, err := rc.Connect(ctx) - if err != nil { - t.Fatal(err) - } - - defer rc.Close(ctx, rendezvous.Happy) + _, err := rc.Connect(ctx) + if err != nil { + t.Fatal(err) + } - for i := 0; i < 20; i++ { - nameplates, err := rc.ListNameplates(ctx) - if err != nil { - t.Fatal(err) - } - if len(nameplates) > 0 { - break - } - time.Sleep(5 * time.Millisecond) - } + defer rc.Close(ctx, rendezvous.Happy) + + for i := 0; i < 20; i++ { + nameplates, err := rc.ListNameplates(ctx) + if err != nil { + t.Fatal(err) + } + if len(nameplates) > 0 { + break + } + time.Sleep(5 * time.Millisecond) + } - defer rc.Close(ctx, rendezvous.Happy) - nameplate, err := nameplateFromCode(code) - if err != nil { - t.Fatal(err) - } + defer rc.Close(ctx, rendezvous.Happy) + nameplate, err := nameplateFromCode(code) + if err != nil { + t.Fatal(err) + } - err = rc.AttachMailbox(ctx, nameplate) - if err != nil { - t.Fatal(err) - } + err = rc.AttachMailbox(ctx, nameplate) + if err != nil { + t.Fatal(err) + } - msgs := rc.MsgChan(ctx) + msgs := rc.MsgChan(ctx) - select { - case <-msgs: - case <-time.After(5 * time.Second): - t.Fatal("timeout waiting for c0 to send a message") - } + select { + case <-msgs: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for c0 to send a message") + } - cancel() + cancel() - select { - case gotErr := <-resultCh: - if gotErr == nil { - t.Fatalf("Expected an error but got none") - } - case <-time.After(5 * time.Second): - // log all goroutines - pprof.Lookup("goroutine").WriteTo(os.Stderr, 1) - t.Fatalf("Timeout waiting for recv cancel") + select { + case gotErr := <-resultCh: + if gotErr == nil { + t.Fatalf("Expected an error but got none") + } + case <-time.After(5 * time.Second): + t.Fatalf("Timeout waiting for recv cancel") + } + }) } } @@ -635,7 +666,7 @@ func TestWormholeDirectoryTransportSendRecvDirect(t *testing.T) { url := rs.WebSocketURL() // disable transit relay for this test - DefaultTransitRelayAddress = "" + DefaultTransitRelayURL = "" var c0Verifier string var c0 Client @@ -743,18 +774,18 @@ func TestSendRecvEmptyFileDirect(t *testing.T) { url := rs.WebSocketURL() - DefaultTransitRelayAddress = "" + DefaultTransitRelayURL = "" - relayServer := newTestRelayServer() + relayServer := newTestTCPRelayServer() defer relayServer.close() var c0 Client c0.RendezvousURL = url - c0.TransitRelayAddress = relayServer.addr + c0.TransitRelayURL = relayServer.url.String() var c1 Client c1.RendezvousURL = url - c1.TransitRelayAddress = relayServer.addr + c1.TransitRelayURL = relayServer.url.String() fileContent := make([]byte, 0) buf := bytes.NewReader(fileContent) @@ -795,16 +826,16 @@ func TestSendRecvEmptyFileViaRelay(t *testing.T) { testDisableLocalListener = true defer func() { testDisableLocalListener = false }() - relayServer := newTestRelayServer() + relayServer := newTestTCPRelayServer() defer relayServer.close() var c0 Client c0.RendezvousURL = url - c0.TransitRelayAddress = relayServer.addr + c0.TransitRelayURL = relayServer.url.String() var c1 Client c1.RendezvousURL = url - c1.TransitRelayAddress = relayServer.addr + c1.TransitRelayURL = relayServer.url.String() fileContent := make([]byte, 0) @@ -835,23 +866,146 @@ func TestSendRecvEmptyFileViaRelay(t *testing.T) { } } +func TestWormholeDirectoryTransportSendRecvRelay(t *testing.T) { + ctx := context.Background() + + rs := rendezvousservertest.NewServer() + defer rs.Close() + + url := rs.WebSocketURL() + + for relayProtocol, newRelayServer := range relayServerConstructors { + t.Run(fmt.Sprintf("With %s relay server", relayProtocol), func(t *testing.T) { + relayServer := newRelayServer() + defer relayServer.close() + relayURL := relayServer.url.String() + + var c0Verifier string + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayURL = relayURL + c0.VerifierOk = func(code string) bool { + c0Verifier = code + return true + } + + var c1Verifier string + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayURL = relayURL + c1.VerifierOk = func(code string) bool { + c1Verifier = code + return true + } + + personalizeContent := make([]byte, 1<<16) + for i := 0; i < len(personalizeContent); i++ { + personalizeContent[i] = byte(i) + } + + bodiceContent := []byte("placarding-whereat") + + entries := []DirectoryEntry{ + { + Path: filepath.Join("skyjacking", "personalize.txt"), + Reader: func() (io.ReadCloser, error) { + b := bytes.NewReader(personalizeContent) + return ioutil.NopCloser(b), nil + }, + }, + { + Path: filepath.Join("skyjacking", "bodice-Maytag.txt"), + Reader: func() (io.ReadCloser, error) { + b := bytes.NewReader(bodiceContent) + return ioutil.NopCloser(b), nil + }, + }, + } + + code, resultCh, err := c0.SendDirectory(ctx, "skyjacking", entries) + if err != nil { + t.Fatal(err) + } + + receiver, err := c1.Receive(ctx, code) + if err != nil { + t.Fatal(err) + } + + got, err := ioutil.ReadAll(receiver) + if err != nil { + t.Fatal(err) + } + + r, err := zip.NewReader(bytes.NewReader(got), int64(len(got))) + if err != nil { + t.Fatal(err) + } + + for _, f := range r.File { + rc, err := f.Open() + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(rc) + if err != nil { + t.Fatal(err) + } + rc.Close() + + if f.Name == "personalize.txt" { + if !bytes.Equal(body, personalizeContent) { + t.Fatal("personalize.txt file content does not match") + } + } else if f.Name == "bodice-Maytag.txt" { + if !bytes.Equal(bodiceContent, body) { + t.Fatalf("bodice-Maytag.txt file content does not match %s vs %s", bodiceContent, body) + } + } else { + t.Fatalf("Unexpected file %s", f.Name) + } + } + + result := <-resultCh + if !result.OK { + t.Fatalf("Expected ok result but got: %+v", result) + } + + if c0Verifier == "" || c1Verifier == "" { + t.Fatalf("Failed to get verifier code c0=%q c1=%q", c0Verifier, c1Verifier) + } + + if c0Verifier != c1Verifier { + t.Fatalf("Expected verifiers to match but were different") + } + }) + } +} + type testRelayServer struct { + *httptest.Server l net.Listener - addr string + url *url.URL + proto string wg sync.WaitGroup mu sync.Mutex streams map[string]net.Conn } -func newTestRelayServer() *testRelayServer { - l, err := net.Listen("tcp", ":0") +func newTestTCPRelayServer() *testRelayServer { + l, err := net.Listen("tcp4", ":0") if err != nil { panic(err) } + url, err := url.Parse("tcp:" + l.Addr().String()) + if err != nil { + panic(err) + } rs := &testRelayServer{ l: l, - addr: l.Addr().String(), + url: url, + proto: "tcp", streams: make(map[string]net.Conn), } @@ -876,6 +1030,39 @@ func (ts *testRelayServer) run() { } } +func newTestWSRelayServer() *testRelayServer { + rs := &testRelayServer{ + proto: "ws", + streams: make(map[string]net.Conn), + } + + smux := http.NewServeMux() + smux.HandleFunc("/", rs.handleWSRelay) + + rs.Server = httptest.NewServer(smux) + url, err := url.Parse("ws://" + rs.Server.Listener.Addr().String()) + if err != nil { + panic(err) + } + rs.url = url + rs.l = rs.Server.Listener + + return rs +} + +func (rs *testRelayServer) handleWSRelay(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + + if err != nil { + return + } + + ctx := context.Background() + conn := websocket.NetConn(ctx, c, websocket.MessageBinary) + rs.wg.Add(1) + go rs.handleConn(conn) +} + var headerPrefix = []byte("please relay ") var headerSide = []byte(" for side ") @@ -974,6 +1161,75 @@ func (ts *testRelayServer) handleConn(c net.Conn) { } } +func TestClient_relayURL_default(t *testing.T) { + var c Client + + DefaultTransitRelayURL = "tcp:transit.magic-wormhole.io:8001" + url, err := c.relayURL() + if err != nil { + t.Error(err) + return + } + if url.Scheme != "tcp" { + t.Error(fmt.Sprintf("invalid protocol, expected tcp, got %v", url)) + } +} + +func TestWormholeFileTransportSendRecvViaWSRelayServer(t *testing.T) { + ctx := context.Background() + + rs := rendezvousservertest.NewServer() + defer rs.Close() + + url := rs.WebSocketURL() + + testDisableLocalListener = true + defer func() { testDisableLocalListener = false }() + + relayServer := newTestWSRelayServer() + relayURL := relayServer.url.String() + defer relayServer.close() + + var c0 Client + c0.RendezvousURL = url + c0.TransitRelayURL = relayURL + + var c1 Client + c1.RendezvousURL = url + c1.TransitRelayURL = relayURL + + fileContent := make([]byte, 1<<16) + for i := 0; i < len(fileContent); i++ { + fileContent[i] = byte(i) + } + + buf := bytes.NewReader(fileContent) + + code, resultCh, err := c0.SendFile(ctx, "file.txt", buf) + if err != nil { + t.Fatal(err) + } + + receiver, err := c1.Receive(ctx, code) + if err != nil { + t.Fatal(err) + } + + got, err := ioutil.ReadAll(receiver) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(got, fileContent) { + t.Fatalf("File contents mismatch") + } + + result := <-resultCh + if !result.OK { + t.Fatalf("Expected ok result but got: %+v", result) + } +} + type splitReader struct { *bytes.Reader offset int