Skip to content

Commit

Permalink
Implement WebSocket relay support
Browse files Browse the repository at this point in the history
Co-authored-by: Ramakrishnan Muthukrishnan <ram@leastauthority.com>
Co-authored-by: Bryan White <bryanchriswhite@gmail.com>
  • Loading branch information
3 people committed Jun 10, 2022
1 parent 83b8359 commit c998891
Show file tree
Hide file tree
Showing 5 changed files with 669 additions and 349 deletions.
170 changes: 117 additions & 53 deletions wormhole/file_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (
"net"
"strconv"
"time"
"net/url"

"github.com/psanford/wormhole-william/internal/crypto"
"golang.org/x/crypto/hkdf"
"golang.org/x/crypto/nacl/secretbox"
"nhooyr.io/websocket"
)

type fileTransportAck struct {
Expand All @@ -33,6 +35,10 @@ const (
TransferText
)

// UnsupportedProtocolErr is used in the default case of protocol switch
// statements to account for unexpected protocols.
var UnsupportedProtocolErr = errors.New("unsupported protocol")

func (tt TransferType) String() string {
switch tt {
case TransferFile:
Expand Down Expand Up @@ -154,18 +160,18 @@ func (d *transportCryptor) writeRecord(msg []byte) error {
return err
}

func newFileTransport(transitKey []byte, appID, relayAddr string) *fileTransport {
func newFileTransport(transitKey []byte, appID string, relayURL *url.URL) *fileTransport {
return &fileTransport{
transitKey: transitKey,
appID: appID,
relayAddr: relayAddr,
relayURL: relayURL,
}
}

type fileTransport struct {
listener net.Listener
relayConn net.Conn
relayAddr string
relayURL *url.URL
transitKey []byte
appID string
}
Expand All @@ -177,19 +183,21 @@ func (t *fileTransport) connectViaRelay(otherTransit *transitMsg) (net.Conn, err
failChan := make(chan string)

var count int

for _, outerHint := range otherTransit.HintsV1 {
if outerHint.Type == "relay-v1" {
for _, innerHint := range outerHint.Hints {
if innerHint.Type == "direct-tcp-v1" {
count++
ctx, cancel := context.WithCancel(context.Background())
addr := net.JoinHostPort(innerHint.Hostname, strconv.Itoa(innerHint.Port))

cancelFuncs[addr] = cancel

go t.connectToRelay(ctx, addr, successChan, failChan)
for _, relay := range otherTransit.HintsV1 {
if relay.Type == "relay-v1" {
for _, endpoint := range relay.Hints {
var addr string
switch endpoint.Type {
case "direct-tcp-v1":
addr = net.JoinHostPort(endpoint.Hostname, strconv.Itoa(endpoint.Port))
case "websocket":
addr = endpoint.Url
}
ctx, cancel := context.WithCancel(context.Background())
cancelFuncs[addr] = cancel

count++
go t.connectToRelay(ctx, successChan, failChan)
}
}
}
Expand Down Expand Up @@ -250,12 +258,30 @@ func (t *fileTransport) connectDirect(otherTransit *transitMsg) (net.Conn, error
return conn, nil
}

func (t *fileTransport) connectToRelay(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) {
func (t *fileTransport) connectToRelay(ctx context.Context, successChan chan net.Conn, failChan chan string) {
var d net.Dialer
conn, err := d.DialContext(ctx, "tcp", addr)
if err != nil {
failChan <- addr
return
var conn net.Conn
var err error
addr := fmt.Sprintf("%s:%s", t.relayURL.Hostname(), t.relayURL.Port())

switch t.relayURL.Scheme {
case "tcp":
conn, err = d.DialContext(ctx, "tcp", addr)

if err != nil {
failChan <- addr
return
}
case "ws", "wss":
var wsconn *websocket.Conn
wsconn, _, err = websocket.Dial(ctx, t.relayURL.String(), nil)

if err != nil {
failChan <- addr
return
}

conn = websocket.NetConn(ctx, wsconn, websocket.MessageBinary)
}

_, err = conn.Write(t.relayHandshakeHeader())
Expand All @@ -277,7 +303,7 @@ func (t *fileTransport) connectToRelay(ctx context.Context, addr string, success
return
}

t.directRecvHandshake(ctx, addr, conn, successChan, failChan)
t.directRecvHandshake(ctx, conn, successChan, failChan)
}

func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) {
Expand All @@ -289,12 +315,13 @@ func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, su
return
}

t.directRecvHandshake(ctx, addr, conn, successChan, failChan)
t.directRecvHandshake(ctx, conn, successChan, failChan)
}

func (t *fileTransport) directRecvHandshake(ctx context.Context, addr string, conn net.Conn, successChan chan net.Conn, failChan chan string) {
func (t *fileTransport) directRecvHandshake(ctx context.Context, conn net.Conn, successChan chan net.Conn, failChan chan string) {
expectHeader := t.senderHandshakeHeader()

addr := t.relayURL.Hostname()
gotHeader := make([]byte, len(expectHeader))

_, err := io.ReadFull(conn, gotHeader)
Expand Down Expand Up @@ -372,27 +399,43 @@ func (t *fileTransport) makeTransitMsg() (*transitMsg, error) {
}

if t.relayConn != nil {
relayHost, portStr, err := net.SplitHostPort(t.relayAddr)
if err != nil {
return nil, err
var relayType string
switch t.relayURL.Scheme {
case "tcp":
relayType = "direct-tcp-v1"
case "ws":
relayType = "websocket"
case "wss":
relayType = "websocket"
default:
return nil, fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme)
}

relayPort, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("port isn't an integer? %s", portStr)
}

msg.HintsV1 = append(msg.HintsV1, transitHintsV1{
Type: "relay-v1",
Hints: []transitHintsV1Hint{
{
Type: "direct-tcp-v1",
Priority: 2.0,
Hostname: relayHost,
Port: relayPort,
if relayType == "direct-tcp-v1" {
var port, err = strconv.Atoi(t.relayURL.Port())
if err != nil {
return nil, fmt.Errorf("invalid port")
}
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{
Type: "relay-v1",
Hints: []transitHintsRelay{
{
Type: relayType,
Hostname: t.relayURL.Hostname(),
Port: port,
},
},
},
})
})
} else {
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{
Type: "relay-v1",
Hints: []transitHintsRelay{
{
Type: relayType,
Url: t.relayURL.String(),
},
},
})
}
}

return &msg, nil
Expand Down Expand Up @@ -449,23 +492,44 @@ func (t *fileTransport) listen() error {
if testDisableLocalListener {
return nil
}
switch t.relayURL.Scheme {
case "tcp":
l, err := net.Listen("tcp", ":0")
if err != nil {
return err
}

l, err := net.Listen("tcp", ":0")
if err != nil {
return err
t.listener = l
case "ws", "wss":
t.listener = nil
default:
return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme)
}

t.listener = l
return nil
}

func (t *fileTransport) listenRelay() error {
if t.relayAddr == "" {
return nil
}
conn, err := net.Dial("tcp", t.relayAddr)
if err != nil {
return err
func (t *fileTransport) listenRelay(ctx context.Context) (err error) {
var conn net.Conn
switch t.relayURL.Scheme {
case "tcp":
// NB: don't dial the relay if we don't have an address.
addr := t.relayURL.Host
if addr == ":0" {
return nil
}
conn, err = net.Dial("tcp", addr)
if err != nil {
return err
}
case "ws", "wss":
c, _, err := websocket.Dial(ctx, t.relayURL.String(), nil)
if err != nil {
return fmt.Errorf("websocket.Dial failed")
}
conn = websocket.NetConn(ctx, c, websocket.MessageBinary)
default:
return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme)
}

_, err = conn.Write(t.relayHandshakeHeader())
Expand Down
6 changes: 5 additions & 1 deletion wormhole/recv.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ func (c *Client) Receive(ctx context.Context, code string) (fr *IncomingMessage,
}

transitKey := deriveTransitKey(clientProto.sharedKey, appID)
transport := newFileTransport(transitKey, appID, c.relayAddr())
relayUrl, err := c.relayURL()
if err != nil {
return nil, fmt.Errorf("Invalid relay URL")
}
transport := newFileTransport(transitKey, appID, relayUrl)

transitMsg, err := transport.makeTransitMsg()
if err != nil {
Expand Down
17 changes: 7 additions & 10 deletions wormhole/send.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,6 @@ func (c *Client) SendText(ctx context.Context, msg string, opts ...SendOption) (
}

func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Reader, opts ...SendOption) (string, chan SendResult, error) {
if err := c.validateRelayAddr(); err != nil {
return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err)
}

var options sendOptions
for _, opt := range opts {
err := opt.setOption(&options)
Expand Down Expand Up @@ -296,15 +292,20 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re
}
}

var relayUrl, err = c.relayURL()
if err != nil {
sendErr(fmt.Errorf("Invalid relay URL"))
return
}
transitKey := deriveTransitKey(clientProto.sharedKey, appID)
transport := newFileTransport(transitKey, appID, c.relayAddr())
transport := newFileTransport(transitKey, appID, relayUrl)
err = transport.listen()
if err != nil {
sendErr(err)
return
}

err = transport.listenRelay()
err = transport.listenRelay(ctx)
if err != nil {
sendErr(err)
return
Expand Down Expand Up @@ -443,10 +444,6 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re
// receiver, a result channel that will be written to after the receiver attempts to read (either successfully or not)
// and an error if one occurred.
func (c *Client) SendFile(ctx context.Context, fileName string, r io.ReadSeeker, opts ...SendOption) (string, chan SendResult, error) {
if err := c.validateRelayAddr(); err != nil {
return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err)
}

size, err := readSeekerSize(r)
if err != nil {
return "", nil, err
Expand Down
Loading

0 comments on commit c998891

Please sign in to comment.