-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add websocket transport to transit relay #63
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,15 +9,18 @@ import ( | |
"errors" | ||
"fmt" | ||
"io" | ||
"log" | ||
"math" | ||
"math/big" | ||
"net" | ||
"strconv" | ||
"time" | ||
"net/url" | ||
|
||
"github.com/psanford/wormhole-william/internal/crypto" | ||
"golang.org/x/crypto/hkdf" | ||
"golang.org/x/crypto/nacl/secretbox" | ||
"nhooyr.io/websocket" | ||
) | ||
|
||
type fileTransportAck struct { | ||
|
@@ -33,6 +36,10 @@ const ( | |
TransferText | ||
) | ||
|
||
// UnsupportedProtocolErr is used in the default case of protocol switch | ||
// statements to account for unexpected protocols. | ||
var UnsupportedProtocolErr = errors.New("unsupported protocol") | ||
|
||
func (tt TransferType) String() string { | ||
switch tt { | ||
case TransferFile: | ||
|
@@ -154,18 +161,18 @@ func (d *transportCryptor) writeRecord(msg []byte) error { | |
return err | ||
} | ||
|
||
func newFileTransport(transitKey []byte, appID, relayAddr string) *fileTransport { | ||
func newFileTransport(transitKey []byte, appID string, relayURL *url.URL) *fileTransport { | ||
return &fileTransport{ | ||
transitKey: transitKey, | ||
appID: appID, | ||
relayAddr: relayAddr, | ||
relayURL: relayURL, | ||
} | ||
} | ||
|
||
type fileTransport struct { | ||
listener net.Listener | ||
relayConn net.Conn | ||
relayAddr string | ||
relayURL *url.URL | ||
transitKey []byte | ||
appID string | ||
} | ||
|
@@ -177,19 +184,24 @@ func (t *fileTransport) connectViaRelay(otherTransit *transitMsg) (net.Conn, err | |
failChan := make(chan string) | ||
|
||
var count int | ||
|
||
for _, outerHint := range otherTransit.HintsV1 { | ||
if outerHint.Type == "relay-v1" { | ||
for _, innerHint := range outerHint.Hints { | ||
if innerHint.Type == "direct-tcp-v1" { | ||
count++ | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
addr := net.JoinHostPort(innerHint.Hostname, strconv.Itoa(innerHint.Port)) | ||
|
||
cancelFuncs[addr] = cancel | ||
|
||
go t.connectToRelay(ctx, addr, successChan, failChan) | ||
for _, relay := range otherTransit.HintsV1 { | ||
log.Println("relay: ", relay) | ||
if relay.Type == "relay-v1" { | ||
for _, endpoint := range relay.Hints { | ||
log.Println("- endpoint: ", endpoint) | ||
var addr string | ||
switch endpoint.Type { | ||
case "direct-tcp-v1": | ||
addr = net.JoinHostPort(endpoint.Hostname, strconv.Itoa(endpoint.Port)) | ||
case "websocket": | ||
addr = endpoint.Url | ||
} | ||
ctx, cancel := context.WithCancel(context.Background()) | ||
cancelFuncs[addr] = cancel | ||
log.Println("- addr: ", addr) | ||
|
||
count++ | ||
go t.connectToRelay(ctx, successChan, failChan) | ||
} | ||
} | ||
} | ||
|
@@ -250,34 +262,57 @@ func (t *fileTransport) connectDirect(otherTransit *transitMsg) (net.Conn, error | |
return conn, nil | ||
} | ||
|
||
func (t *fileTransport) connectToRelay(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) { | ||
func (t *fileTransport) connectToRelay(ctx context.Context, successChan chan net.Conn, failChan chan string) { | ||
var d net.Dialer | ||
conn, err := d.DialContext(ctx, "tcp", addr) | ||
if err != nil { | ||
failChan <- addr | ||
return | ||
var conn net.Conn | ||
var err error | ||
addr := fmt.Sprintf("%s:%s", t.relayURL.Hostname(), t.relayURL.Port()) | ||
|
||
log.Println("- Relay: ", t.relayURL) | ||
switch t.relayURL.Scheme { | ||
case "tcp": | ||
conn, err = d.DialContext(ctx, "tcp", addr) | ||
log.Println(" - ", conn, err) | ||
|
||
if err != nil { | ||
failChan <- addr | ||
return | ||
} | ||
case "ws", "wss": | ||
var wsconn *websocket.Conn | ||
wsconn, _, err = websocket.Dial(ctx, t.relayURL.String(), nil) | ||
|
||
if err != nil { | ||
failChan <- addr | ||
return | ||
} | ||
|
||
conn = websocket.NetConn(ctx, wsconn, websocket.MessageBinary) | ||
} | ||
|
||
_, err = conn.Write(t.relayHandshakeHeader()) | ||
log.Println(" - ", err) | ||
if err != nil { | ||
failChan <- addr | ||
return | ||
} | ||
gotOk := make([]byte, 3) | ||
_, err = io.ReadFull(conn, gotOk) | ||
log.Println(" - ", err) | ||
if err != nil { | ||
conn.Close() | ||
failChan <- addr | ||
return | ||
} | ||
|
||
if !bytes.Equal(gotOk, []byte("ok\n")) { | ||
log.Println(" - Not OK") | ||
conn.Close() | ||
failChan <- addr | ||
return | ||
} | ||
|
||
t.directRecvHandshake(ctx, addr, conn, successChan, failChan) | ||
t.directRecvHandshake(ctx, conn, successChan, failChan) | ||
} | ||
|
||
func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, successChan chan net.Conn, failChan chan string) { | ||
|
@@ -289,12 +324,13 @@ func (t *fileTransport) connectToSingleHost(ctx context.Context, addr string, su | |
return | ||
} | ||
|
||
t.directRecvHandshake(ctx, addr, conn, successChan, failChan) | ||
t.directRecvHandshake(ctx, conn, successChan, failChan) | ||
} | ||
|
||
func (t *fileTransport) directRecvHandshake(ctx context.Context, addr string, conn net.Conn, successChan chan net.Conn, failChan chan string) { | ||
func (t *fileTransport) directRecvHandshake(ctx context.Context, conn net.Conn, successChan chan net.Conn, failChan chan string) { | ||
expectHeader := t.senderHandshakeHeader() | ||
|
||
addr := t.relayURL.Hostname() | ||
gotHeader := make([]byte, len(expectHeader)) | ||
|
||
_, err := io.ReadFull(conn, gotHeader) | ||
|
@@ -372,27 +408,43 @@ func (t *fileTransport) makeTransitMsg() (*transitMsg, error) { | |
} | ||
|
||
if t.relayConn != nil { | ||
relayHost, portStr, err := net.SplitHostPort(t.relayAddr) | ||
if err != nil { | ||
return nil, err | ||
var relayType string | ||
switch t.relayURL.Scheme { | ||
case "tcp": | ||
relayType = "direct-tcp-v1" | ||
case "ws": | ||
relayType = "websocket-v1" | ||
case "wss": | ||
relayType = "websocket-v1" | ||
default: | ||
return nil, fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme) | ||
} | ||
|
||
relayPort, err := strconv.Atoi(portStr) | ||
if err != nil { | ||
return nil, fmt.Errorf("port isn't an integer? %s", portStr) | ||
} | ||
|
||
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{ | ||
Type: "relay-v1", | ||
Hints: []transitHintsV1Hint{ | ||
{ | ||
Type: "direct-tcp-v1", | ||
Priority: 2.0, | ||
Hostname: relayHost, | ||
Port: relayPort, | ||
if relayType == "direct-tcp-v1" { | ||
var port, err = strconv.Atoi(t.relayURL.Port()) | ||
if err != nil { | ||
return nil, fmt.Errorf("invalid port") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would make more sense to use |
||
} | ||
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{ | ||
Type: "relay-v1", | ||
Hints: []transitHintsRelay{ | ||
{ | ||
Type: relayType, | ||
Hostname: t.relayURL.Hostname(), | ||
Port: port, | ||
}, | ||
}, | ||
}, | ||
}) | ||
}) | ||
} else { | ||
msg.HintsV1 = append(msg.HintsV1, transitHintsV1{ | ||
Type: "relay-v1", | ||
Hints: []transitHintsRelay{ | ||
{ | ||
Type: relayType, | ||
Url: t.relayURL.String(), | ||
}, | ||
}, | ||
}) | ||
} | ||
} | ||
|
||
return &msg, nil | ||
|
@@ -449,23 +501,45 @@ func (t *fileTransport) listen() error { | |
if testDisableLocalListener { | ||
return nil | ||
} | ||
switch t.relayURL.Scheme { | ||
case "tcp": | ||
l, err := net.Listen("tcp", ":0") | ||
if err != nil { | ||
return err | ||
} | ||
|
||
l, err := net.Listen("tcp", ":0") | ||
if err != nil { | ||
return err | ||
t.listener = l | ||
case "ws", "wss": | ||
t.listener = nil | ||
default: | ||
return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme) | ||
} | ||
|
||
t.listener = l | ||
return nil | ||
} | ||
|
||
func (t *fileTransport) listenRelay() error { | ||
if t.relayAddr == "" { | ||
return nil | ||
} | ||
conn, err := net.Dial("tcp", t.relayAddr) | ||
if err != nil { | ||
return err | ||
func (t *fileTransport) listenRelay(ctx context.Context) (err error) { | ||
var conn net.Conn | ||
log.Println("URL: ", t.relayURL) | ||
switch t.relayURL.Scheme { | ||
case "tcp": | ||
// NB: don't dial the relay if we don't have an address. | ||
addr := t.relayURL.Host | ||
if addr == ":0" { | ||
return nil | ||
} | ||
conn, err = net.Dial("tcp", addr) | ||
if err != nil { | ||
return err | ||
} | ||
case "ws", "wss": | ||
c, _, err := websocket.Dial(ctx, t.relayURL.String(), nil) | ||
if err != nil { | ||
return fmt.Errorf("websocket.Dial failed") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as my comment above. |
||
} | ||
conn = websocket.NetConn(ctx, c, websocket.MessageBinary) | ||
default: | ||
return fmt.Errorf("%w: %s", UnsupportedProtocolErr, t.relayURL.Scheme) | ||
} | ||
|
||
_, err = conn.Write(t.relayHandshakeHeader()) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -157,7 +157,11 @@ func (c *Client) Receive(ctx context.Context, code string) (fr *IncomingMessage, | |
} | ||
|
||
transitKey := deriveTransitKey(clientProto.sharedKey, appID) | ||
transport := newFileTransport(transitKey, appID, c.relayAddr()) | ||
relayUrl, err := c.relayURL() | ||
if err != nil { | ||
return nil, fmt.Errorf("Invalid relay URL") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as my comment above. Also, should this one perhaps be moved out to a global variable given that it's used in more than one place? |
||
} | ||
transport := newFileTransport(transitKey, appID, relayUrl) | ||
|
||
transitMsg, err := transport.makeTransitMsg() | ||
if err != nil { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,10 +185,6 @@ func (c *Client) SendText(ctx context.Context, msg string, opts ...SendOption) ( | |
} | ||
|
||
func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Reader, opts ...SendOption) (string, chan SendResult, error) { | ||
if err := c.validateRelayAddr(); err != nil { | ||
return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err) | ||
} | ||
|
||
var options sendOptions | ||
for _, opt := range opts { | ||
err := opt.setOption(&options) | ||
|
@@ -296,15 +292,20 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re | |
} | ||
} | ||
|
||
var relayUrl, err = c.relayURL() | ||
if err != nil { | ||
sendErr(fmt.Errorf("Invalid relay URL")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as my comment above. |
||
return | ||
} | ||
transitKey := deriveTransitKey(clientProto.sharedKey, appID) | ||
transport := newFileTransport(transitKey, appID, c.relayAddr()) | ||
transport := newFileTransport(transitKey, appID, relayUrl) | ||
err = transport.listen() | ||
if err != nil { | ||
sendErr(err) | ||
return | ||
} | ||
|
||
err = transport.listenRelay() | ||
err = transport.listenRelay(ctx) | ||
if err != nil { | ||
sendErr(err) | ||
return | ||
|
@@ -443,10 +444,6 @@ func (c *Client) sendFileDirectory(ctx context.Context, offer *offerMsg, r io.Re | |
// receiver, a result channel that will be written to after the receiver attempts to read (either successfully or not) | ||
// and an error if one occurred. | ||
func (c *Client) SendFile(ctx context.Context, fileName string, r io.ReadSeeker, opts ...SendOption) (string, chan SendResult, error) { | ||
if err := c.validateRelayAddr(); err != nil { | ||
return "", nil, fmt.Errorf("invalid TransitRelayAddress: %s", err) | ||
} | ||
|
||
size, err := readSeekerSize(r) | ||
if err != nil { | ||
return "", nil, err | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, switching to
net/url
would mean, thetcp
based relay urls would also be encoded astcp://host:port
, right? Would that break compatibility with other wormhole clients that still follow the twisted url format? In that case, we should note that in theREADME
perhaps?