diff --git a/go.mod b/go.mod index 8a0df36..efff385 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,7 @@ require ( github.com/BurntSushi/toml v1.4.0 github.com/kevinburke/ssh_config v1.2.0 golang.org/x/crypto v0.27.0 + golang.org/x/sync v0.8.0 ) -require ( - github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 // indirect - golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.8.0 // indirect - golang.org/x/sys v0.25.0 // indirect -) +require golang.org/x/sys v0.25.0 // indirect diff --git a/go.sum b/go.sum index 600c472..31f7430 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,9 @@ github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= -github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= diff --git a/internal/proxy/socks5.go b/internal/proxy/socks5.go new file mode 100644 index 0000000..7cf3e8b --- /dev/null +++ b/internal/proxy/socks5.go @@ -0,0 +1,710 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package socks5 is a SOCKS5 server implementation. +// +// This is used for userspace networking in Tailscale. Specifically, +// this is used for dialing out of the machine to other nodes, without +// the host kernel's involvement, so it doesn't proper routing tables, +// TUN, IPv6, etc. This package is meant to only handle the SOCKS5 protocol +// details and not any integration with Tailscale internals itself. +// +// The glue between this package and Tailscale is in net/socks5/tssocks. + +package proxy + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "net" + "strconv" + "time" +) + +// Authentication METHODs described in RFC 1928, section 3. +const ( + noAuthRequired byte = 0 + passwordAuth byte = 2 + noAcceptableAuth byte = 255 +) + +// passwordAuthVersion is the auth version byte described in RFC 1929. +const passwordAuthVersion = 1 + +// socks5Version is the byte that represents the SOCKS version +// in requests. +const socks5Version byte = 5 + +// commandType are the bytes sent in SOCKS5 packets +// that represent the kind of connection the client needs. +type commandType byte + +// The set of valid SOCKS5 commands as described in RFC 1928. +const ( + connect commandType = 1 + bind commandType = 2 + udpAssociate commandType = 3 +) + +// addrType are the bytes sent in SOCKS5 packets +// that represent particular address types. +type addrType byte + +// The set of valid SOCKS5 address types as defined in RFC 1928. +const ( + ipv4 addrType = 1 + domainName addrType = 3 + ipv6 addrType = 4 +) + +// replyCode are the bytes sent in SOCKS5 packets +// that represent replies from the server to a client +// request. +type replyCode byte + +// The set of valid SOCKS5 reply types as per the RFC 1928. +const ( + success replyCode = 0 + generalFailure replyCode = 1 + connectionNotAllowed replyCode = 2 + networkUnreachable replyCode = 3 + hostUnreachable replyCode = 4 + connectionRefused replyCode = 5 + ttlExpired replyCode = 6 + commandNotSupported replyCode = 7 + addrTypeNotSupported replyCode = 8 +) + +type logf func(format string, args... any) + +// Server is a SOCKS5 proxy server. +type Server struct { + // Logf optionally specifies the logger to use. + // If nil, the standard logger is used. + Logf logf + + // Dialer optionally specifies the dialer to use for outgoing connections. + // If nil, the net package's standard dialer is used. + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Username and Password, if set, are the credential clients must provide. + Username string + Password string +} + +func (s *Server) dial(ctx context.Context, network, addr string) (net.Conn, error) { + dial := s.Dialer + if dial == nil { + dialer := &net.Dialer{} + dial = dialer.DialContext + } + return dial(ctx, network, addr) +} + +func (s *Server) logf(format string, args ...any) { + logf := s.Logf + if logf == nil { + logf = log.Printf + } + logf(format, args...) +} + +// Serve accepts and handles incoming connections on the given listener. +func (s *Server) Serve(l net.Listener) error { + defer l.Close() + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.ServeConn(conn) + } +} + +func (s *Server) ServeConn(conn net.Conn) error { + defer conn.Close() + socksC := &Conn{logf: s.Logf, clientConn: conn, srv: s} + err := socksC.Run() + if err != nil { + s.logf("client connection failed: %v", err) + } + return nil +} + +// Conn is a SOCKS5 connection for client to reach +// server. +type Conn struct { + // The struct is filled by each of the internal + // methods in turn as the transaction progresses. + logf logf + srv *Server + clientConn net.Conn + request *request + + udpClientAddr net.Addr +} + +// Run starts the new connection. +func (c *Conn) Run() error { + needAuth := c.srv.Username != "" || c.srv.Password != "" + authMethod := noAuthRequired + if needAuth { + authMethod = passwordAuth + } + + err := parseClientGreeting(c.clientConn, authMethod) + if err != nil { + c.clientConn.Write([]byte{socks5Version, noAcceptableAuth}) + return err + } + c.clientConn.Write([]byte{socks5Version, authMethod}) + if !needAuth { + return c.handleRequest() + } + + user, pwd, err := parseClientAuth(c.clientConn) + if err != nil || user != c.srv.Username || pwd != c.srv.Password { + c.clientConn.Write([]byte{1, 1}) // auth error + return err + } + c.clientConn.Write([]byte{1, 0}) // auth success + + return c.handleRequest() +} + +func (c *Conn) handleRequest() error { + req, err := parseClientRequest(c.clientConn) + if err != nil { + res := errorResponse(generalFailure) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + + c.request = req + switch req.command { + case connect: + return c.handleTCP() + case udpAssociate: + return c.handleUDP() + default: + res := errorResponse(commandNotSupported) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return fmt.Errorf("unsupported command %v", req.command) + } +} + +func (c *Conn) handleTCP() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + srv, err := c.srv.dial( + ctx, + "tcp", + c.request.destination.hostPort(), + ) + if err != nil { + res := errorResponse(generalFailure) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + defer srv.Close() + + localAddr := srv.LocalAddr().String() + serverAddr, serverPort, err := splitHostPort(localAddr) + if err != nil { + return err + } + + res := &response{ + reply: success, + bindAddr: socksAddr{ + addrType: getAddrType(serverAddr), + addr: serverAddr, + port: serverPort, + }, + } + buf, err := res.marshal() + if err != nil { + res = errorResponse(generalFailure) + buf, _ = res.marshal() + } + c.clientConn.Write(buf) + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(c.clientConn, srv) + if err != nil { + err = fmt.Errorf("from backend to client: %w", err) + } + errc <- err + }() + go func() { + _, err := io.Copy(srv, c.clientConn) + if err != nil { + err = fmt.Errorf("from client to backend: %w", err) + } + errc <- err + }() + return <-errc +} + +func (c *Conn) handleUDP() error { + // The DST.ADDR and DST.PORT fields contain the address and port that + // the client expects to use to send UDP datagrams on for the + // association. The server MAY use this information to limit access + // to the association. + // @see Page 6, https://datatracker.ietf.org/doc/html/rfc1928. + // + // We do NOT limit the access from the client currently in this implementation. + _ = c.request.destination + + addr := c.clientConn.LocalAddr() + host, _, err := net.SplitHostPort(addr.String()) + if err != nil { + return err + } + clientUDPConn, err := net.ListenPacket("udp", net.JoinHostPort(host, "0")) + if err != nil { + res := errorResponse(generalFailure) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + defer clientUDPConn.Close() + + serverUDPConn, err := net.ListenPacket("udp", "[::]:0") + if err != nil { + res := errorResponse(generalFailure) + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + defer serverUDPConn.Close() + + bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) + if err != nil { + return err + } + + res := &response{ + reply: success, + bindAddr: socksAddr{ + addrType: getAddrType(bindAddr), + addr: bindAddr, + port: bindPort, + }, + } + buf, err := res.marshal() + if err != nil { + res = errorResponse(generalFailure) + buf, _ = res.marshal() + } + c.clientConn.Write(buf) + + return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) +} + +func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + const bufferSize = 8 * 1024 + const readTimeout = 5 * time.Second + + // client -> target + go func() { + defer cancel() + buf := make([]byte, bufferSize) + for { + select { + case <-ctx.Done(): + return + default: + err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) + if err != nil { + if isTimeout(err) { + continue + } + if errors.Is(err, net.ErrClosed) { + return + } + c.logf("udp transfer: handle udp request fail: %v", err) + } + } + } + }() + + // target -> client + go func() { + defer cancel() + buf := make([]byte, bufferSize) + for { + select { + case <-ctx.Done(): + return + default: + err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout) + if err != nil { + if isTimeout(err) { + continue + } + if errors.Is(err, net.ErrClosed) { + return + } + c.logf("udp transfer: handle udp response fail: %v", err) + } + } + } + }() + + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. RFC1928 + _, err := io.Copy(io.Discard, associatedTCP) + if err != nil { + err = fmt.Errorf("udp associated tcp conn: %w", err) + } + return err +} + +func (c *Conn) handleUDPRequest( + clientConn net.PacketConn, + targetConn net.PacketConn, + buf []byte, + readTimeout time.Duration, +) error { + // add a deadline for the read to avoid blocking forever + _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) + n, addr, err := clientConn.ReadFrom(buf) + if err != nil { + return fmt.Errorf("read from client: %w", err) + } + c.udpClientAddr = addr + req, data, err := parseUDPRequest(buf[:n]) + if err != nil { + return fmt.Errorf("parse udp request: %w", err) + } + targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort()) + if err != nil { + c.logf("resolve target addr fail: %v", err) + } + + nn, err := targetConn.WriteTo(data, targetAddr) + if err != nil { + return fmt.Errorf("write to target %s fail: %w", targetAddr, err) + } + if nn != len(data) { + return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) + } + return nil +} + +func (c *Conn) handleUDPResponse( + targetConn net.PacketConn, + clientConn net.PacketConn, + buf []byte, + readTimeout time.Duration, +) error { + // add a deadline for the read to avoid blocking forever + _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) + n, addr, err := targetConn.ReadFrom(buf) + if err != nil { + return fmt.Errorf("read from target: %w", err) + } + host, port, err := splitHostPort(addr.String()) + if err != nil { + return fmt.Errorf("split host port: %w", err) + } + hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}} + pkt, err := hdr.marshal() + if err != nil { + return fmt.Errorf("marshal udp request: %w", err) + } + data := append(pkt, buf[:n]...) + // use addr from client to send back + nn, err := clientConn.WriteTo(data, c.udpClientAddr) + if err != nil { + return fmt.Errorf("write to client: %w", err) + } + if nn != len(data) { + return fmt.Errorf("write to client: %w", io.ErrShortWrite) + } + return nil +} + +func isTimeout(err error) bool { + terr, ok := errors.Unwrap(err).(interface{ Timeout() bool }) + return ok && terr.Timeout() +} + +func splitHostPort(hostport string) (host string, port uint16, err error) { + host, portStr, err := net.SplitHostPort(hostport) + if err != nil { + return "", 0, err + } + portInt, err := strconv.Atoi(portStr) + if err != nil { + return "", 0, err + } + if portInt < 0 || portInt > 65535 { + return "", 0, fmt.Errorf("invalid port number %d", portInt) + } + return host, uint16(portInt), nil +} + +// parseClientGreeting parses a request initiation packet. +func parseClientGreeting(r io.Reader, authMethod byte) error { + var hdr [2]byte + _, err := io.ReadFull(r, hdr[:]) + if err != nil { + return fmt.Errorf("could not read packet header") + } + if hdr[0] != socks5Version { + return fmt.Errorf("incompatible SOCKS version") + } + count := int(hdr[1]) + methods := make([]byte, count) + _, err = io.ReadFull(r, methods) + if err != nil { + return fmt.Errorf("could not read methods") + } + for _, m := range methods { + if m == authMethod { + return nil + } + } + return fmt.Errorf("no acceptable auth methods") +} + +func parseClientAuth(r io.Reader) (usr, pwd string, err error) { + var hdr [2]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return "", "", fmt.Errorf("could not read auth packet header") + } + if hdr[0] != passwordAuthVersion { + return "", "", fmt.Errorf("bad SOCKS auth version") + } + usrLen := int(hdr[1]) + usrBytes := make([]byte, usrLen) + if _, err := io.ReadFull(r, usrBytes); err != nil { + return "", "", fmt.Errorf("could not read auth packet username") + } + var hdrPwd [1]byte + if _, err := io.ReadFull(r, hdrPwd[:]); err != nil { + return "", "", fmt.Errorf("could not read auth packet password length") + } + pwdLen := int(hdrPwd[0]) + pwdBytes := make([]byte, pwdLen) + if _, err := io.ReadFull(r, pwdBytes); err != nil { + return "", "", fmt.Errorf("could not read auth packet password") + } + return string(usrBytes), string(pwdBytes), nil +} + +func getAddrType(addr string) addrType { + if ip := net.ParseIP(addr); ip != nil { + if ip.To4() != nil { + return ipv4 + } + return ipv6 + } + return domainName +} + +// request represents data contained within a SOCKS5 +// connection request packet. +type request struct { + command commandType + destination socksAddr +} + +// parseClientRequest converts raw packet bytes into a +// SOCKS5Request struct. +func parseClientRequest(r io.Reader) (*request, error) { + var hdr [3]byte + _, err := io.ReadFull(r, hdr[:]) + if err != nil { + return nil, fmt.Errorf("could not read packet header") + } + cmd := hdr[1] + + destination, err := parseSocksAddr(r) + return &request{ + command: commandType(cmd), + destination: destination, + }, err +} + +type socksAddr struct { + addrType addrType + addr string + port uint16 +} + +var zeroSocksAddr = socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} + +func parseSocksAddr(r io.Reader) (addr socksAddr, err error) { + var addrTypeData [1]byte + _, err = io.ReadFull(r, addrTypeData[:]) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read address type") + } + + dstAddrType := addrType(addrTypeData[0]) + var destination string + switch dstAddrType { + case ipv4: + var ip [4]byte + _, err = io.ReadFull(r, ip[:]) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read IPv4 address") + } + destination = net.IP(ip[:]).String() + case domainName: + var dstSizeByte [1]byte + _, err = io.ReadFull(r, dstSizeByte[:]) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read domain name size") + } + dstSize := int(dstSizeByte[0]) + domainName := make([]byte, dstSize) + _, err = io.ReadFull(r, domainName) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read domain name") + } + destination = string(domainName) + case ipv6: + var ip [16]byte + _, err = io.ReadFull(r, ip[:]) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read IPv6 address") + } + destination = net.IP(ip[:]).String() + default: + return socksAddr{}, fmt.Errorf("unsupported address type") + } + var portBytes [2]byte + _, err = io.ReadFull(r, portBytes[:]) + if err != nil { + return socksAddr{}, fmt.Errorf("could not read port") + } + port := binary.BigEndian.Uint16(portBytes[:]) + return socksAddr{ + addrType: dstAddrType, + addr: destination, + port: port, + }, nil +} + +func (s socksAddr) marshal() ([]byte, error) { + var addr []byte + switch s.addrType { + case ipv4: + addr = net.ParseIP(s.addr).To4() + if addr == nil { + return nil, fmt.Errorf("invalid IPv4 address for binding") + } + case domainName: + if len(s.addr) > 255 { + return nil, fmt.Errorf("invalid domain name for binding") + } + addr = make([]byte, 0, len(s.addr)+1) + addr = append(addr, byte(len(s.addr))) + addr = append(addr, []byte(s.addr)...) + case ipv6: + addr = net.ParseIP(s.addr).To16() + if addr == nil { + return nil, fmt.Errorf("invalid IPv6 address for binding") + } + default: + return nil, fmt.Errorf("unsupported address type") + } + + pkt := []byte{byte(s.addrType)} + pkt = append(pkt, addr...) + pkt = binary.BigEndian.AppendUint16(pkt, s.port) + return pkt, nil +} +func (s socksAddr) hostPort() string { + return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) +} + +// response contains the contents of +// a response packet sent from the proxy +// to the client. +type response struct { + reply replyCode + bindAddr socksAddr +} + +func errorResponse(code replyCode) *response { + return &response{reply: code, bindAddr: zeroSocksAddr} +} + +// marshal converts a SOCKS5Response struct into +// a packet. If res.reply == Success, it may throw an error on +// receiving an invalid bind address. Otherwise, it will not throw. +func (res *response) marshal() ([]byte, error) { + pkt := make([]byte, 3) + pkt[0] = socks5Version + pkt[1] = byte(res.reply) + pkt[2] = 0 // null reserved byte + + addrPkt, err := res.bindAddr.marshal() + if err != nil { + return nil, err + } + + return append(pkt, addrPkt...), nil +} + +type udpRequest struct { + frag byte + addr socksAddr +} + +// +----+------+------+----------+----------+----------+ +// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | +// +----+------+------+----------+----------+----------+ +// | 2 | 1 | 1 | Variable | 2 | Variable | +// +----+------+------+----------+----------+----------+ +func parseUDPRequest(data []byte) (*udpRequest, []byte, error) { + if len(data) < 4 { + return nil, nil, fmt.Errorf("invalid packet length") + } + + // reserved bytes + if !(data[0] == 0 && data[1] == 0) { + return nil, nil, fmt.Errorf("invalid udp request header") + } + + frag := data[2] + + reader := bytes.NewReader(data[3:]) + addr, err := parseSocksAddr(reader) + bodyLen := reader.Len() // (*bytes.Reader).Len() return unread data length + body := data[len(data)-bodyLen:] + return &udpRequest{ + frag: frag, + addr: addr, + }, body, err +} + +func (u *udpRequest) marshal() ([]byte, error) { + pkt := make([]byte, 3) + pkt[0] = 0 + pkt[1] = 0 + pkt[2] = u.frag + + addrPkt, err := u.addr.marshal() + if err != nil { + return nil, err + } + + return append(pkt, addrPkt...), nil +} diff --git a/internal/tunnel/tunnel.go b/internal/tunnel/tunnel.go index 658eb37..9dcbf98 100644 --- a/internal/tunnel/tunnel.go +++ b/internal/tunnel/tunnel.go @@ -9,7 +9,7 @@ import ( "time" "github.com/alebeck/boring/internal/log" - "github.com/armon/go-socks5" + "github.com/alebeck/boring/internal/proxy" "golang.org/x/crypto/ssh" ) @@ -196,16 +196,11 @@ func (t *Tunnel) tunnel(local, remote net.Conn) { } func (t *Tunnel) handleSocks() { - conf := &socks5.Config{ - Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + serv := &proxy.Server{ + Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { return t.client.Dial(network, addr) }, } - serv, err := socks5.New(conf) - if err != nil { - log.Errorf("could not create socks server: %v", err) - return - } for { conn, err := t.listener.Accept()