Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add websocket transport to transit relay #63

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 127 additions & 53 deletions wormhole/file_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@ import (
"errors"
"fmt"
"io"
"log"
"math"
"math/big"
"net"
"strconv"
"time"
"net/url"

"github.com/psanford/wormhole-william/internal/crypto"
"golang.org/x/crypto/hkdf"
"golang.org/x/crypto/nacl/secretbox"
"nhooyr.io/websocket"
)

type fileTransportAck struct {
Expand All @@ -33,6 +36,10 @@ const (
TransferText
)

// UnsupportedProtocolErr is used in the default case of protocol switch
// statements to account for unexpected protocols.
var UnsupportedProtocolErr = errors.New("unsupported protocol")

func (tt TransferType) String() string {
switch tt {
case TransferFile:
Expand Down Expand Up @@ -154,18 +161,18 @@ func (d *transportCryptor) writeRecord(msg []byte) error {
return err
}

func newFileTransport(transitKey []byte, appID, relayAddr string) *fileTransport {
func newFileTransport(transitKey []byte, appID string, relayURL *url.URL) *fileTransport {
return &fileTransport{
transitKey: transitKey,
appID: appID,
relayAddr: relayAddr,
relayURL: relayURL,
}
}

type fileTransport struct {
listener net.Listener
relayConn net.Conn
relayAddr string
relayURL *url.URL
transitKey []byte
appID string
}
Expand All @@ -177,19 +184,24 @@ func (t *fileTransport) connectViaRelay(otherTransit *transitMsg) (net.Conn, err
failChan := make(chan string)

var count int

for _, outerHint := range otherTransit.HintsV1 {
if outerHint.Type == "relay-v1" {
for _, innerHint := range outerHint.Hints {
if innerHint.Type == "direct-tcp-v1" {
count++
ctx, cancel := context.WithCancel(context.Background())
addr := net.JoinHostPort(innerHint.Hostname, strconv.Itoa(innerHint.Port))

cancelFuncs[addr] = cancel

go t.connectToRelay(ctx, addr, successChan, failChan)
for _, relay := range otherTransit.HintsV1 {
log.Println("relay: ", relay)
if relay.Type == "relay-v1" {
for _, endpoint := range relay.Hints {
log.Println("- endpoint: ", endpoint)
var addr string
switch endpoint.Type {
case "direct-tcp-v1":
addr = net.JoinHostPort(endpoint.Hostname, strconv.Itoa(endpoint.Port))
case "websocket":
addr = endpoint.Url
}
ctx, cancel := context.WithCancel(context.Background())
cancelFuncs[addr] = cancel
log.Println("- addr: ", addr)

count++
go t.connectToRelay(ctx, successChan, failChan)
}
}
}
Expand Down Expand Up @@ -250,34 +262,57 @@ func (t *fileTransport) connectDirect(otherTransit *transitMsg) (net.Conn, error
return conn, nil
}

func (t *fileTransport) connectToRelay(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) {
func (t *fileTransport) connectToRelay(ctx context.Context, successChan chan net.Conn, failChan chan string) {
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
failChan <- addr
return
var conn net.Conn
var err error
addr := fmt.Sprintf("%s:%s", t.relayURL.Hostname(), t.relayURL.Port())

log.Println("- Relay: ", t.relayURL)
switch t.relayURL.Scheme {
case "tcp":
conn, err = d.DialContext(ctx, "tcp", addr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, switching to net/url would mean, the tcp based relay urls would also be encoded as tcp://host:port, right? Would that break compatibility with other wormhole clients that still follow the twisted url format? In that case, we should note that in the README perhaps?

log.Println(" - ", conn, err)

if err != nil {
failChan <- addr
return
}
case "ws", "wss":
var wsconn *websocket.Conn
wsconn, _, err = websocket.Dial(ctx, t.relayURL.String(), nil)

if err != nil {
failChan <- addr
return
}

conn = websocket.NetConn(ctx, wsconn, websocket.MessageBinary)
}

_, err = conn.Write(t.relayHandshakeHeader())
log.Println(" - ", err)
if err != nil {
failChan <- addr
return
}
gotOk := make([]byte, 3)
_, err = io.ReadFull(conn, gotOk)
log.Println(" - ", err)
if err != nil {
conn.Close()
failChan <- addr
return
}

if !bytes.Equal(gotOk, []byte("ok\n")) {
log.Println(" - Not OK")
conn.Close()
failChan <- addr
return
}

t.directRecvHandshake(ctx, addr, conn, successChan, failChan)
t.directRecvHandshake(ctx, conn, successChan, failChan)
}

func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) {
Expand All @@ -289,12 +324,13 @@ func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, su
return
}

t.directRecvHandshake(ctx, addr, conn, successChan, failChan)
t.directRecvHandshake(ctx, conn, successChan, failChan)
}

func (t *fileTransport) directRecvHandshake(ctx context.Context, addr string, conn net.Conn, successChan chan net.Conn, failChan chan string) {
func (t *fileTransport) directRecvHandshake(ctx context.Context, conn net.Conn, successChan chan net.Conn, failChan chan string) {
expectHeader := t.senderHandshakeHeader()

addr := t.relayURL.Hostname()
gotHeader := make([]byte, len(expectHeader))

_, err := io.ReadFull(conn, gotHeader)
Expand Down Expand Up @@ -372,27 +408,43 @@ func (t *fileTransport) makeTransitMsg() (*transitMsg, error) {
}

if t.relayConn != nil {
relayHost, portStr, err := net.SplitHostPort(t.relayAddr)
if err != nil {
return nil, err
var relayType string
switch t.relayURL.Scheme {
case "tcp":
relayType = "direct-tcp-v1"
case "ws":
relayType = "websocket-v1"
case "wss":
relayType = "websocket-v1"
default:
return nil, fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme)
}

relayPort, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("port isn't an integer? %s", portStr)
}

msg.HintsV1 = append(msg.HintsV1, transitHintsV1{
Type: "relay-v1",
Hints: []transitHintsV1Hint{
{
Type: "direct-tcp-v1",
Priority: 2.0,
Hostname: relayHost,
Port: relayPort,
if relayType == "direct-tcp-v1" {
var port, err = strconv.Atoi(t.relayURL.Port())
if err != nil {
return nil, fmt.Errorf("invalid port")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would make more sense to use errors.New here instead.

}
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{
Type: "relay-v1",
Hints: []transitHintsRelay{
{
Type: relayType,
Hostname: t.relayURL.Hostname(),
Port: port,
},
},
},
})
})
} else {
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{
Type: "relay-v1",
Hints: []transitHintsRelay{
{
Type: relayType,
Url: t.relayURL.String(),
},
},
})
}
}

return &msg, nil
Expand Down Expand Up @@ -449,23 +501,45 @@ func (t *fileTransport) listen() error {
if testDisableLocalListener {
return nil
}
switch t.relayURL.Scheme {
case "tcp":
l, err := net.Listen("tcp", ":0")
if err != nil {
return err
}

l, err := net.Listen("tcp", ":0")
if err != nil {
return err
t.listener = l
case "ws", "wss":
t.listener = nil
default:
return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme)
}

t.listener = l
return nil
}

func (t *fileTransport) listenRelay() error {
if t.relayAddr == "" {
return nil
}
conn, err := net.Dial("tcp", t.relayAddr)
if err != nil {
return err
func (t *fileTransport) listenRelay(ctx context.Context) (err error) {
var conn net.Conn
log.Println("URL: ", t.relayURL)
switch t.relayURL.Scheme {
case "tcp":
// NB: don't dial the relay if we don't have an address.
addr := t.relayURL.Host
if addr == ":0" {
return nil
}
conn, err = net.Dial("tcp", addr)
if err != nil {
return err
}
case "ws", "wss":
c, _, err := websocket.Dial(ctx, t.relayURL.String(), nil)
if err != nil {
return fmt.Errorf("websocket.Dial failed")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as my comment above.

}
conn = websocket.NetConn(ctx, c, websocket.MessageBinary)
default:
return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme)
}

_, err = conn.Write(t.relayHandshakeHeader())
Expand Down
6 changes: 5 additions & 1 deletion wormhole/recv.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ func (c *Client) Receive(ctx context.Context, code string) (fr *IncomingMessage,
}

transitKey := deriveTransitKey(clientProto.sharedKey, appID)
transport := newFileTransport(transitKey, appID, c.relayAddr())
relayUrl, err := c.relayURL()
if err != nil {
return nil, fmt.Errorf("Invalid relay URL")
Copy link
Contributor

@Jacalz Jacalz May 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as my comment above. Also, should this one perhaps be moved out to a global variable given that it's used in more than one place?

}
transport := newFileTransport(transitKey, appID, relayUrl)

transitMsg, err := transport.makeTransitMsg()
if err != nil {
Expand Down
17 changes: 7 additions & 10 deletions wormhole/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,6 @@ func (c *Client) SendText(ctx context.Context, msg string, opts ...SendOption) (
}

func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Reader, opts ...SendOption) (string, chan SendResult, error) {
if err := c.validateRelayAddr(); err != nil {
return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err)
}

var options sendOptions
for _, opt := range opts {
err := opt.setOption(&options)
Expand Down Expand Up @@ -296,15 +292,20 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re
}
}

var relayUrl, err = c.relayURL()
if err != nil {
sendErr(fmt.Errorf("Invalid relay URL"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as my comment above.

return
}
transitKey := deriveTransitKey(clientProto.sharedKey, appID)
transport := newFileTransport(transitKey, appID, c.relayAddr())
transport := newFileTransport(transitKey, appID, relayUrl)
err = transport.listen()
if err != nil {
sendErr(err)
return
}

err = transport.listenRelay()
err = transport.listenRelay(ctx)
if err != nil {
sendErr(err)
return
Expand Down Expand Up @@ -443,10 +444,6 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re
// receiver, a result channel that will be written to after the receiver attempts to read (either successfully or not)
// and an error if one occurred.
func (c *Client) SendFile(ctx context.Context, fileName string, r io.ReadSeeker, opts ...SendOption) (string, chan SendResult, error) {
if err := c.validateRelayAddr(); err != nil {
return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err)
}

size, err := readSeekerSize(r)
if err != nil {
return "", nil, err
Expand Down
Loading