Skip to content

Commit

Permalink
Add UDP support to SOCKS5 dialer
Browse files Browse the repository at this point in the history
  • Loading branch information
marshall-lee committed Oct 6, 2023
1 parent 88194ad commit d431cc5
Show file tree
Hide file tree
Showing 8 changed files with 347 additions and 92 deletions.
72 changes: 42 additions & 30 deletions internal/socks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ var (
aLongTimeAgo = time.Unix(1, 0)
)

func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
host, port, err := splitHostPort(address)
func (d *Dialer) connect(ctx context.Context, c net.Conn, req Request) (conn net.Conn, _ net.Addr, ctxErr error) {
var udpHeader []byte

host, port, err := splitHostPort(req.DstAddress)
if err != nil {
return nil, err
return c, nil, err
}
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
c.SetDeadline(deadline)
defer c.SetDeadline(noDeadline)
if req.Cmd != CmdUDPAssociate {
defer c.SetDeadline(noDeadline)
}
}
if ctx != context.Background() {
errCh := make(chan error, 1)
Expand All @@ -47,14 +51,15 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
}()
}

conn = c
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
b = append(b, Version5)
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
b = append(b, 1, byte(AuthMethodNotRequired))
} else {
ams := d.AuthMethods
if len(ams) > 255 {
return nil, errors.New("too many authentication methods")
return c, nil, errors.New("too many authentication methods")
}
b = append(b, byte(len(ams)))
for _, am := range ams {
Expand All @@ -69,11 +74,11 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
am := AuthMethod(b[1])
if am == AuthMethodNoAcceptableMethods {
return nil, errors.New("no acceptable authentication methods")
return c, nil, errors.New("no acceptable authentication methods")
}
if d.Authenticate != nil {
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
Expand All @@ -82,7 +87,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
}

b = b[:0]
b = append(b, Version5, byte(d.cmd), 0)
b = append(b, Version5, byte(req.Cmd), 0)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
b = append(b, AddrTypeIPv4)
Expand All @@ -91,17 +96,23 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
b = append(b, AddrTypeIPv6)
b = append(b, ip6...)
} else {
return nil, errors.New("unknown address type")
return c, nil, errors.New("unknown address type")
}
} else {
if len(host) > 255 {
return nil, errors.New("FQDN too long")
return c, nil, errors.New("FQDN too long")
}
b = append(b, AddrTypeFQDN)
b = append(b, byte(len(host)))
b = append(b, host...)
}
b = append(b, byte(port>>8), byte(port))

if req.Cmd == CmdUDPAssociate {
udpHeader = make([]byte, len(b))
copy(udpHeader[3:], b[3:])
}

if _, ctxErr = c.Write(b); ctxErr != nil {
return
}
Expand All @@ -110,17 +121,18 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
return
}
if b[0] != Version5 {
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
}
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
return nil, errors.New("unknown error " + cmdErr.String())
return c, nil, errors.New("unknown error " + cmdErr.String())
}
if b[2] != 0 {
return nil, errors.New("non-zero reserved field")
return c, nil, errors.New("non-zero reserved field")
}
l := 2
addrType := b[3]
var a Addr
switch b[3] {
switch addrType {
case AddrTypeIPv4:
l += net.IPv4len
a.IP = make(net.IP, net.IPv4len)
Expand All @@ -129,12 +141,13 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
a.IP = make(net.IP, net.IPv6len)
case AddrTypeFQDN:
if _, err := io.ReadFull(c, b[:1]); err != nil {
return nil, err
return c, nil, err
}
l += int(b[0])
default:
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
return c, nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
}

if cap(b) < l {
b = make([]byte, l)
} else {
Expand All @@ -149,20 +162,19 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
a.Name = string(b[:len(b)-2])
}
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
return &a, nil
}

func splitHostPort(address string) (string, int, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", 0, err
}
portnum, err := strconv.Atoi(port)
if err != nil {
return "", 0, err
}
if 1 > portnum || portnum > 0xffff {
return "", 0, errors.New("port number out of range " + port)
if req.Cmd == CmdUDPAssociate {
var uc net.Conn
if uc, err = d.proxyDial(ctx, req.UDPNetwork, a.String()); err != nil {
return c, &a, err
}
c.SetDeadline(noDeadline)
go func() {
defer uc.Close()
io.Copy(io.Discard, c)
}()
return udpConn{Conn: uc, socksConn: c, header: udpHeader}, &a, nil
}
return host, portnum, nil

return c, &a, nil
}
116 changes: 111 additions & 5 deletions internal/socks/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package socks_test

import (
"context"
"errors"
"io"
"math/rand"
"net"
Expand All @@ -15,6 +16,7 @@ import (

"golang.org/x/net/internal/socks"
"golang.org/x/net/internal/sockstest"
"golang.org/x/net/nettest"
)

func TestDial(t *testing.T) {
Expand All @@ -33,7 +35,7 @@ func TestDial(t *testing.T) {
Username: "username",
Password: "password",
}).Authenticate
c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(context.Background(), "tcp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
Expand All @@ -60,7 +62,7 @@ func TestDial(t *testing.T) {
Username: "username",
Password: "password",
}).Authenticate
a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
a, err := d.DialWithConn(context.Background(), c, "tcp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
Expand All @@ -79,7 +81,7 @@ func TestDial(t *testing.T) {
defer cancel()
dialErr := make(chan error)
go func() {
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
if err == nil {
c.Close()
}
Expand All @@ -101,7 +103,7 @@ func TestDial(t *testing.T) {
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
if err == nil {
c.Close()
}
Expand All @@ -119,14 +121,88 @@ func TestDial(t *testing.T) {
for i := 0; i < 2*len(rogueCmdList); i++ {
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
defer cancel()
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
if err == nil {
t.Log(c.(*socks.Conn).BoundAddr())
c.Close()
t.Error("should fail")
}
}
})
t.Run("UDPAssociate", func(t *testing.T) {
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
if err != nil {
t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
c.Close()
if network := c.RemoteAddr().Network(); network != "udp" {
t.Errorf("RemoteAddr().Network(): expected \"udp\" got %q", network)
}
expected := "127.0.0.1:5964"
if remoteAddr := c.RemoteAddr().String(); remoteAddr != expected {
t.Errorf("RemoteAddr(): expected %q got %q", expected, remoteAddr)
}
if boundAddr := c.(*socks.Conn).BoundAddr().String(); boundAddr != expected {
t.Errorf("BoundAddr(): expected %q got %q", expected, boundAddr)
}
})
t.Run("UDPAssociateWithReadAndWrite", func(t *testing.T) {
rc, cmdFunc, err := packetListenerCmdFunc()
if err != nil {
t.Fatal(err)
}
defer rc.Close()
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, cmdFunc)
if err != nil {
t.Fatal(err)
}
defer ss.Close()
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String())
if err != nil {
t.Fatal(err)
}
defer c.Close()
buf := make([]byte, 32)
expected := "HELLO OUTBOUND"
n, err := c.Write([]byte(expected))
if err != nil {
t.Fatal(err)
}
if len(expected) != n {
t.Errorf("Write(): expected %v bytes got %v", len(expected), n)
}
n, addr, err := rc.ReadFrom(buf)
if err != nil {
t.Fatal(err)
}
data, err := socks.SkipUDPHeader(buf[:n])
if err != nil {
t.Fatal(err)
}
if actual := string(data); expected != actual {
t.Errorf("ReadFrom(): expected %q got %q", expected, actual)
}
udpHeader := []byte{0x00, 0x00, 0x00, 0x01, 0x7f, 0x00, 0x00, 0x01, 0x17, 0x4b}
expected = "HELLO INBOUND"
_, err = rc.WriteTo(append(udpHeader, []byte(expected)...), addr)
if err != nil {
t.Fatal(err)
}
n, err = c.Read(buf)
if err != nil {
t.Fatal(err)
}
if actual := string(buf[:n]); expected != actual {
t.Errorf("Read(): expected %q got %q", expected, actual)
}
})
}

func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
Expand Down Expand Up @@ -168,3 +244,33 @@ func parseDialError(err error) (perr, nerr error) {
perr = err
return
}

func packetListenerCmdFunc() (net.PacketConn, func(io.ReadWriter, []byte) error, error) {
conn, err := nettest.NewLocalPacketListener("udp")
if err != nil {
return nil, nil, err
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
return conn, func(rw io.ReadWriter, b []byte) error {
req, err := sockstest.ParseCmdRequest(b)
if err != nil {
return err
}
if req.Cmd != socks.CmdUDPAssociate {
return errors.New("unexpected command")
}
b, err = sockstest.MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &socks.Addr{IP: localAddr.IP, Port: localAddr.Port})
if err != nil {
return err
}
n, err := rw.Write(b)
if err != nil {
return err
}
if n != len(b) {
return errors.New("short write")
}
_, err = io.Copy(io.Discard, rw)
return err
}, nil
}
Loading

0 comments on commit d431cc5

Please sign in to comment.