Skip to content

Commit

Permalink
Let the proxy generate relay connections via a callback
Browse files Browse the repository at this point in the history
  • Loading branch information
rg0now committed Nov 29, 2023
1 parent cdfa85a commit c8f02cc
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 46 deletions.
1 change: 1 addition & 0 deletions examples/turn-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func main() {
TURNServerURI: fmt.Sprintf("turn:%s:%d?transport=udp", *host, *port),
Listeners: []net.Listener{listener},
PeerAddr: clientAddr,
RelayConnGen: turn.DefaultRelayConnGen(true),
AuthGen: func() (string, string, error) { return cred[0], cred[1], nil },
})
if err != nil {
Expand Down
103 changes: 57 additions & 46 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,53 @@ import (
// AuthGen is a callback used to generate TURN credentials.
type AuthGen func() (string, string, error)

// RelayConnGen is used to generate a PacketConns that the proxy can use to connect to the TURN server.
type RelayConnGen func(protocol, addr string) (net.PacketConn, error)

// DefaultRelayConnGen is a default relay connection generator that knows how to generate relay connections for the proxy. Set insecure to true to let the proxy accept self-signed server-side TLS certificates.
func DefaultRelayConnGen(insecure bool) RelayConnGen {
return func(proto, addr string) (net.PacketConn, error) {
switch proto {
case "udp":
t, err := net.ListenPacket("udp", "0.0.0.0:0")
if err != nil {
return nil, err
}
return t, nil
case "tcp":
c, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return NewSTUNConn(c), nil
case "tls":
c, err := tls.Dial("tcp", addr, &tls.Config{
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: insecure, //nolint:gosec
})
if err != nil {
return nil, err
}
return NewSTUNConn(c), nil
case "dtls":
server, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}

conn, err := dtls.Dial("udp", server, &dtls.Config{
InsecureSkipVerify: insecure,
})
if err != nil {
return nil, err
}
return NewSTUNConn(conn), err
default:
return nil, fmt.Errorf("%w: invalid protocol", errProxyConnFail)
}
}
}

// ProxyConfig configures the Pion TURN proxy.
type ProxyConfig struct {
// Listeners is a list of client listeners.
Expand All @@ -33,12 +80,12 @@ type ProxyConfig struct {
// Address:port for the peer to access.
PeerAddr string

// Callback for generating PacketConns that can be used by the proxy to connect to the TURN server.
RelayConnGen RelayConnGen

// AuthGen is a callback used to generate TURN authentication credentials.
AuthGen AuthGen

// InsecureMode controls whether self-signed TLS certificates are accepted from the server.
InsecureMode bool

// LoggerFactory must be set for logging from this proxy.
LoggerFactory logging.LoggerFactory

Expand All @@ -58,8 +105,8 @@ func (c *ProxyConfig) validate() error {
return fmt.Errorf("%w: invalid peer", errInvalidProxyConfig)
}

if c.AuthGen == nil {
return fmt.Errorf("%w: invalid credential generator", errInvalidProxyConfig)
if c.AuthGen == nil || c.RelayConnGen == nil {
return fmt.Errorf("%w: invalid auth or relay-conn generator", errInvalidProxyConfig)
}

return nil
Expand All @@ -78,8 +125,8 @@ type Proxy struct {
peerAddr net.Addr
connTrack map[string]*connection // Conntrack table.
lock *sync.Mutex // Sync access to the conntrack state.
relayConnGen RelayConnGen
authGen AuthGen
insecure bool
loggerFactory logging.LoggerFactory
log logging.LeveledLogger
cancel context.CancelFunc
Expand Down Expand Up @@ -115,8 +162,8 @@ func NewProxy(config ProxyConfig) (*Proxy, error) {
peerAddr: peer,
connTrack: make(map[string]*connection),
lock: new(sync.Mutex),
relayConnGen: config.RelayConnGen,
authGen: config.AuthGen,
insecure: config.InsecureMode,
loggerFactory: loggerFactory,
log: loggerFactory.NewLogger("proxy"),
net: config.Net,
Expand Down Expand Up @@ -197,46 +244,10 @@ func (p *Proxy) allocate(conn net.Conn) (*connection, error) {
}
}

var turnConn net.PacketConn
switch proto {
case "udp":
t, errListen := p.net.ListenPacket("udp", "0.0.0.0:0")
if errListen != nil {
return nil, errListen
}
turnConn = t
case "tcp":
c, errListen := net.Dial("tcp", server)
if errListen != nil {
return nil, errListen
}
turnConn = NewSTUNConn(c)
case "tls":
c, errListen := tls.Dial("tcp", server, &tls.Config{
MinVersion: tls.VersionTLS10,
InsecureSkipVerify: p.insecure, //nolint:gosec
})
if errListen != nil {
return nil, errListen
}
turnConn = NewSTUNConn(c)
case "dtls":
addr, errListen := net.ResolveUDPAddr("udp", server)
if errListen != nil {
return nil, errListen
}

conn, errListen := dtls.Dial("udp", addr, &dtls.Config{
InsecureSkipVerify: p.insecure,
})
if errListen != nil {
return nil, errListen
}
turnConn = NewSTUNConn(conn)
default:
return nil, fmt.Errorf("%w: invalid protocol", errProxyConnFail)
turnConn, err := p.relayConnGen(proto, server)
if err != nil {
return nil, err
}

client.conn = turnConn

turnClient, err := NewClient(&ClientConfig{
Expand Down
1 change: 1 addition & 0 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ func testProxyTransport(t *testing.T, c testCase) {
TURNServerURI: c.uri,
Listeners: []net.Listener{c.listener},
PeerAddr: "127.0.0.1:5001",
RelayConnGen: DefaultRelayConnGen(true),
AuthGen: func() (string, string, error) { return "user1", "pass1", nil },
LoggerFactory: c.loggerFactory,
})
Expand Down

0 comments on commit c8f02cc

Please sign in to comment.