Skip to content

Commit 3ad0b7a

Browse files
committed
proxy: Add UDP support to SOCKS5 dialer
Dial("udp", address) should open a UDP connection with Read()/Write() methods supporting packet encapsulation as described in RFC 1928.
1 parent 88194ad commit 3ad0b7a

File tree

8 files changed

+347
-92
lines changed

8 files changed

+347
-92
lines changed

internal/socks/client.go

+42-30
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@ var (
1818
aLongTimeAgo = time.Unix(1, 0)
1919
)
2020

21-
func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) {
22-
host, port, err := splitHostPort(address)
21+
func (d *Dialer) connect(ctx context.Context, c net.Conn, req Request) (conn net.Conn, _ net.Addr, ctxErr error) {
22+
var udpHeader []byte
23+
24+
host, port, err := splitHostPort(req.DstAddress)
2325
if err != nil {
24-
return nil, err
26+
return c, nil, err
2527
}
2628
if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
2729
c.SetDeadline(deadline)
28-
defer c.SetDeadline(noDeadline)
30+
if req.Cmd != CmdUDPAssociate {
31+
defer c.SetDeadline(noDeadline)
32+
}
2933
}
3034
if ctx != context.Background() {
3135
errCh := make(chan error, 1)
@@ -47,14 +51,15 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
4751
}()
4852
}
4953

54+
conn = c
5055
b := make([]byte, 0, 6+len(host)) // the size here is just an estimate
5156
b = append(b, Version5)
5257
if len(d.AuthMethods) == 0 || d.Authenticate == nil {
5358
b = append(b, 1, byte(AuthMethodNotRequired))
5459
} else {
5560
ams := d.AuthMethods
5661
if len(ams) > 255 {
57-
return nil, errors.New("too many authentication methods")
62+
return c, nil, errors.New("too many authentication methods")
5863
}
5964
b = append(b, byte(len(ams)))
6065
for _, am := range ams {
@@ -69,11 +74,11 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
6974
return
7075
}
7176
if b[0] != Version5 {
72-
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
77+
return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
7378
}
7479
am := AuthMethod(b[1])
7580
if am == AuthMethodNoAcceptableMethods {
76-
return nil, errors.New("no acceptable authentication methods")
81+
return c, nil, errors.New("no acceptable authentication methods")
7782
}
7883
if d.Authenticate != nil {
7984
if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil {
@@ -82,7 +87,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
8287
}
8388

8489
b = b[:0]
85-
b = append(b, Version5, byte(d.cmd), 0)
90+
b = append(b, Version5, byte(req.Cmd), 0)
8691
if ip := net.ParseIP(host); ip != nil {
8792
if ip4 := ip.To4(); ip4 != nil {
8893
b = append(b, AddrTypeIPv4)
@@ -91,17 +96,23 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
9196
b = append(b, AddrTypeIPv6)
9297
b = append(b, ip6...)
9398
} else {
94-
return nil, errors.New("unknown address type")
99+
return c, nil, errors.New("unknown address type")
95100
}
96101
} else {
97102
if len(host) > 255 {
98-
return nil, errors.New("FQDN too long")
103+
return c, nil, errors.New("FQDN too long")
99104
}
100105
b = append(b, AddrTypeFQDN)
101106
b = append(b, byte(len(host)))
102107
b = append(b, host...)
103108
}
104109
b = append(b, byte(port>>8), byte(port))
110+
111+
if req.Cmd == CmdUDPAssociate {
112+
udpHeader = make([]byte, len(b))
113+
copy(udpHeader[3:], b[3:])
114+
}
115+
105116
if _, ctxErr = c.Write(b); ctxErr != nil {
106117
return
107118
}
@@ -110,17 +121,18 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
110121
return
111122
}
112123
if b[0] != Version5 {
113-
return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
124+
return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0])))
114125
}
115126
if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded {
116-
return nil, errors.New("unknown error " + cmdErr.String())
127+
return c, nil, errors.New("unknown error " + cmdErr.String())
117128
}
118129
if b[2] != 0 {
119-
return nil, errors.New("non-zero reserved field")
130+
return c, nil, errors.New("non-zero reserved field")
120131
}
121132
l := 2
133+
addrType := b[3]
122134
var a Addr
123-
switch b[3] {
135+
switch addrType {
124136
case AddrTypeIPv4:
125137
l += net.IPv4len
126138
a.IP = make(net.IP, net.IPv4len)
@@ -129,12 +141,13 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
129141
a.IP = make(net.IP, net.IPv6len)
130142
case AddrTypeFQDN:
131143
if _, err := io.ReadFull(c, b[:1]); err != nil {
132-
return nil, err
144+
return c, nil, err
133145
}
134146
l += int(b[0])
135147
default:
136-
return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
148+
return c, nil, errors.New("unknown address type " + strconv.Itoa(int(b[3])))
137149
}
150+
138151
if cap(b) < l {
139152
b = make([]byte, l)
140153
} else {
@@ -149,20 +162,19 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net
149162
a.Name = string(b[:len(b)-2])
150163
}
151164
a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1])
152-
return &a, nil
153-
}
154165

155-
func splitHostPort(address string) (string, int, error) {
156-
host, port, err := net.SplitHostPort(address)
157-
if err != nil {
158-
return "", 0, err
159-
}
160-
portnum, err := strconv.Atoi(port)
161-
if err != nil {
162-
return "", 0, err
163-
}
164-
if 1 > portnum || portnum > 0xffff {
165-
return "", 0, errors.New("port number out of range " + port)
166+
if req.Cmd == CmdUDPAssociate {
167+
var uc net.Conn
168+
if uc, err = d.proxyDial(ctx, req.UDPNetwork, a.String()); err != nil {
169+
return c, &a, err
170+
}
171+
c.SetDeadline(noDeadline)
172+
go func() {
173+
defer uc.Close()
174+
io.Copy(io.Discard, c)
175+
}()
176+
return udpConn{Conn: uc, socksConn: c, header: udpHeader}, &a, nil
166177
}
167-
return host, portnum, nil
178+
179+
return c, &a, nil
168180
}

internal/socks/dial_test.go

+111-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package socks_test
66

77
import (
88
"context"
9+
"errors"
910
"io"
1011
"math/rand"
1112
"net"
@@ -15,6 +16,7 @@ import (
1516

1617
"golang.org/x/net/internal/socks"
1718
"golang.org/x/net/internal/sockstest"
19+
"golang.org/x/net/nettest"
1820
)
1921

2022
func TestDial(t *testing.T) {
@@ -33,7 +35,7 @@ func TestDial(t *testing.T) {
3335
Username: "username",
3436
Password: "password",
3537
}).Authenticate
36-
c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String())
38+
c, err := d.DialContext(context.Background(), "tcp", ss.TargetAddrPort().String())
3739
if err != nil {
3840
t.Fatal(err)
3941
}
@@ -60,7 +62,7 @@ func TestDial(t *testing.T) {
6062
Username: "username",
6163
Password: "password",
6264
}).Authenticate
63-
a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String())
65+
a, err := d.DialWithConn(context.Background(), c, "tcp", ss.TargetAddrPort().String())
6466
if err != nil {
6567
t.Fatal(err)
6668
}
@@ -79,7 +81,7 @@ func TestDial(t *testing.T) {
7981
defer cancel()
8082
dialErr := make(chan error)
8183
go func() {
82-
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
84+
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
8385
if err == nil {
8486
c.Close()
8587
}
@@ -101,7 +103,7 @@ func TestDial(t *testing.T) {
101103
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
102104
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
103105
defer cancel()
104-
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
106+
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
105107
if err == nil {
106108
c.Close()
107109
}
@@ -119,14 +121,88 @@ func TestDial(t *testing.T) {
119121
for i := 0; i < 2*len(rogueCmdList); i++ {
120122
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond))
121123
defer cancel()
122-
c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String())
124+
c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String())
123125
if err == nil {
124126
t.Log(c.(*socks.Conn).BoundAddr())
125127
c.Close()
126128
t.Error("should fail")
127129
}
128130
}
129131
})
132+
t.Run("UDPAssociate", func(t *testing.T) {
133+
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
134+
if err != nil {
135+
t.Fatal(err)
136+
}
137+
defer ss.Close()
138+
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
139+
c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String())
140+
if err != nil {
141+
t.Fatal(err)
142+
}
143+
c.Close()
144+
if network := c.RemoteAddr().Network(); network != "udp" {
145+
t.Errorf("RemoteAddr().Network(): expected \"udp\" got %q", network)
146+
}
147+
expected := "127.0.0.1:5964"
148+
if remoteAddr := c.RemoteAddr().String(); remoteAddr != expected {
149+
t.Errorf("RemoteAddr(): expected %q got %q", expected, remoteAddr)
150+
}
151+
if boundAddr := c.(*socks.Conn).BoundAddr().String(); boundAddr != expected {
152+
t.Errorf("BoundAddr(): expected %q got %q", expected, boundAddr)
153+
}
154+
})
155+
t.Run("UDPAssociateWithReadAndWrite", func(t *testing.T) {
156+
rc, cmdFunc, err := packetListenerCmdFunc()
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
defer rc.Close()
161+
ss, err := sockstest.NewServer(sockstest.NoAuthRequired, cmdFunc)
162+
if err != nil {
163+
t.Fatal(err)
164+
}
165+
defer ss.Close()
166+
d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String())
167+
c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String())
168+
if err != nil {
169+
t.Fatal(err)
170+
}
171+
defer c.Close()
172+
buf := make([]byte, 32)
173+
expected := "HELLO OUTBOUND"
174+
n, err := c.Write([]byte(expected))
175+
if err != nil {
176+
t.Fatal(err)
177+
}
178+
if len(expected) != n {
179+
t.Errorf("Write(): expected %v bytes got %v", len(expected), n)
180+
}
181+
n, addr, err := rc.ReadFrom(buf)
182+
if err != nil {
183+
t.Fatal(err)
184+
}
185+
data, err := socks.SkipUDPHeader(buf[:n])
186+
if err != nil {
187+
t.Fatal(err)
188+
}
189+
if actual := string(data); expected != actual {
190+
t.Errorf("ReadFrom(): expected %q got %q", expected, actual)
191+
}
192+
udpHeader := []byte{0x00, 0x00, 0x00, 0x01, 0x7f, 0x00, 0x00, 0x01, 0x17, 0x4b}
193+
expected = "HELLO INBOUND"
194+
_, err = rc.WriteTo(append(udpHeader, []byte(expected)...), addr)
195+
if err != nil {
196+
t.Fatal(err)
197+
}
198+
n, err = c.Read(buf)
199+
if err != nil {
200+
t.Fatal(err)
201+
}
202+
if actual := string(buf[:n]); expected != actual {
203+
t.Errorf("Read(): expected %q got %q", expected, actual)
204+
}
205+
})
130206
}
131207

132208
func blackholeCmdFunc(rw io.ReadWriter, b []byte) error {
@@ -168,3 +244,33 @@ func parseDialError(err error) (perr, nerr error) {
168244
perr = err
169245
return
170246
}
247+
248+
func packetListenerCmdFunc() (net.PacketConn, func(io.ReadWriter, []byte) error, error) {
249+
conn, err := nettest.NewLocalPacketListener("udp")
250+
if err != nil {
251+
return nil, nil, err
252+
}
253+
localAddr := conn.LocalAddr().(*net.UDPAddr)
254+
return conn, func(rw io.ReadWriter, b []byte) error {
255+
req, err := sockstest.ParseCmdRequest(b)
256+
if err != nil {
257+
return err
258+
}
259+
if req.Cmd != socks.CmdUDPAssociate {
260+
return errors.New("unexpected command")
261+
}
262+
b, err = sockstest.MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &socks.Addr{IP: localAddr.IP, Port: localAddr.Port})
263+
if err != nil {
264+
return err
265+
}
266+
n, err := rw.Write(b)
267+
if err != nil {
268+
return err
269+
}
270+
if n != len(b) {
271+
return errors.New("short write")
272+
}
273+
_, err = io.Copy(io.Discard, rw)
274+
return err
275+
}, nil
276+
}

0 commit comments

Comments
 (0)