Skip to content

Commit

Permalink
Refactor(tunnel): modularize tunnel pkg (#393)
Browse files Browse the repository at this point in the history
  • Loading branch information
xjasonlyu authored Aug 31, 2024
1 parent 71c45ef commit fd98f65
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 75 deletions.
7 changes: 3 additions & 4 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/device"
"github.com/xjasonlyu/tun2socks/v2/core/option"
"github.com/xjasonlyu/tun2socks/v2/dialer"
"github.com/xjasonlyu/tun2socks/v2/engine/mirror"
"github.com/xjasonlyu/tun2socks/v2/log"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/restapi"
Expand Down Expand Up @@ -130,7 +129,7 @@ func general(k *Key) error {
if k.UDPTimeout < time.Second {
return errors.New("invalid udp timeout value")
}
tunnel.SetUDPTimeout(k.UDPTimeout)
tunnel.T().SetUDPTimeout(k.UDPTimeout)
}
return nil
}
Expand Down Expand Up @@ -192,7 +191,7 @@ func netstack(k *Key) (err error) {
if _defaultProxy, err = parseProxy(k.Proxy); err != nil {
return
}
proxy.SetDialer(_defaultProxy)
tunnel.T().SetDialer(_defaultProxy)

if _defaultDevice, err = parseDevice(k.Device, uint32(k.MTU)); err != nil {
return
Expand Down Expand Up @@ -226,7 +225,7 @@ func netstack(k *Key) (err error) {

if _defaultStack, err = core.CreateStack(&core.Config{
LinkEndpoint: _defaultDevice,
TransportHandler: &mirror.Tunnel{},
TransportHandler: tunnel.T(),
MulticastGroups: multicastGroups,
Options: opts,
}); err != nil {
Expand Down
18 changes: 0 additions & 18 deletions engine/mirror/tunnel.go

This file was deleted.

37 changes: 37 additions & 0 deletions tunnel/global.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package tunnel

import (
"sync"

"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

var (
_globalMu sync.RWMutex
_globalT *Tunnel
)

func init() {
ReplaceGlobal(New(&proxy.Base{}, statistic.DefaultManager))
T().ProcessAsync()
}

// T returns the global Tunnel, which can be reconfigured with
// ReplaceGlobal. It's safe for concurrent use.
func T() *Tunnel {
_globalMu.RLock()
t := _globalT
_globalMu.RUnlock()
return t
}

// ReplaceGlobal replaces the global Tunnel, and returns a function
// to restore the original values. It's safe for concurrent use.
func ReplaceGlobal(t *Tunnel) func() {
_globalMu.Lock()
prev := _globalT
_globalT = t
_globalMu.Unlock()
return func() { ReplaceGlobal(prev) }
}
1 change: 0 additions & 1 deletion tunnel/statistic/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func init() {
uploadTotal: atomic.NewInt64(0),
downloadTotal: atomic.NewInt64(0),
}

go DefaultManager.handle()
}

Expand Down
10 changes: 0 additions & 10 deletions tunnel/statistic/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ func NewTCPTracker(conn net.Conn, metadata *M.Metadata, manager *Manager) net.Co
return tt
}

// DefaultTCPTracker returns a new net.Conn(*tcpTacker) with default manager.
func DefaultTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
return NewTCPTracker(conn, metadata, DefaultManager)
}

func (tt *tcpTracker) ID() string {
return tt.UUID.String()
}
Expand Down Expand Up @@ -120,11 +115,6 @@ func NewUDPTracker(conn net.PacketConn, metadata *M.Metadata, manager *Manager)
return ut
}

// DefaultUDPTracker returns a new net.PacketConn(*udpTacker) with default manager.
func DefaultUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
return NewUDPTracker(conn, metadata, DefaultManager)
}

func (ut *udpTracker) ID() string {
return ut.UUID.String()
}
Expand Down
18 changes: 8 additions & 10 deletions tunnel/tcp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tunnel

import (
"context"
"io"
"net"
"sync"
Expand All @@ -10,16 +11,10 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

const (
// tcpWaitTimeout implements a TCP half-close timeout.
tcpWaitTimeout = 60 * time.Second
)

func handleTCPConn(originConn adapter.TCPConn) {
func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) {
defer originConn.Close()

id := originConn.ID()
Expand All @@ -31,21 +26,24 @@ func handleTCPConn(originConn adapter.TCPConn) {
DstPort: id.LocalPort,
}

remoteConn, err := proxy.Dial(metadata)
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()

remoteConn, err := t.Dialer().DialContext(ctx, metadata)
if err != nil {
log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err)
return
}
metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr())

remoteConn = statistic.DefaultTCPTracker(remoteConn, metadata)
remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager)
defer remoteConn.Close()

log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
pipe(originConn, remoteConn)
}

// pipe copies copy data to & from provided net.Conn(s) bidirectionally.
// pipe copies data to & from provided net.Conn(s) bidirectionally.
func pipe(origin, remote net.Conn) {
wg := sync.WaitGroup{}
wg.Add(2)
Expand Down
110 changes: 95 additions & 15 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,116 @@
package tunnel

import (
"context"
"sync"
"time"

"go.uber.org/atomic"

"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

// Unbuffered TCP/UDP queues.
var (
_tcpQueue = make(chan adapter.TCPConn)
_udpQueue = make(chan adapter.UDPConn)
const (
// tcpConnectTimeout is the default timeout for TCP handshakes.
tcpConnectTimeout = 5 * time.Second
// tcpWaitTimeout implements a TCP half-close timeout.
tcpWaitTimeout = 60 * time.Second
// udpSessionTimeout is the default timeout for UDP sessions.
udpSessionTimeout = 60 * time.Second
)

func init() {
go process()
var _ adapter.TransportHandler = (*Tunnel)(nil)

type Tunnel struct {
// Unbuffered TCP/UDP queues.
tcpQueue chan adapter.TCPConn
udpQueue chan adapter.UDPConn

// UDP session timeout.
udpTimeout *atomic.Duration

// Internal proxy.Dialer for Tunnel.
dialerMu sync.RWMutex
dialer proxy.Dialer

// Where the Tunnel statistics are sent to.
manager *statistic.Manager

procOnce sync.Once
procCancel context.CancelFunc
}

func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel {
return &Tunnel{
tcpQueue: make(chan adapter.TCPConn),
udpQueue: make(chan adapter.UDPConn),
udpTimeout: atomic.NewDuration(udpSessionTimeout),
dialer: dialer,
manager: manager,
procCancel: func() { /* nop */ },
}
}

// TCPIn return fan-in TCP queue.
func TCPIn() chan<- adapter.TCPConn {
return _tcpQueue
func (t *Tunnel) TCPIn() chan<- adapter.TCPConn {
return t.tcpQueue
}

// UDPIn return fan-in UDP queue.
func UDPIn() chan<- adapter.UDPConn {
return _udpQueue
func (t *Tunnel) UDPIn() chan<- adapter.UDPConn {
return t.udpQueue
}

func (t *Tunnel) HandleTCP(conn adapter.TCPConn) {
t.TCPIn() <- conn
}

func process() {
func (t *Tunnel) HandleUDP(conn adapter.UDPConn) {
t.UDPIn() <- conn
}

func (t *Tunnel) process(ctx context.Context) {
for {
select {
case conn := <-_tcpQueue:
go handleTCPConn(conn)
case conn := <-_udpQueue:
go handleUDPConn(conn)
case conn := <-t.tcpQueue:
go t.handleTCPConn(conn)
case conn := <-t.udpQueue:
go t.handleUDPConn(conn)
case <-ctx.Done():
return
}
}
}

// ProcessAsync can be safely called multiple times, but will only be effective once.
func (t *Tunnel) ProcessAsync() {
t.procOnce.Do(func() {
ctx, cancel := context.WithCancel(context.Background())
t.procCancel = cancel
go t.process(ctx)
})
}

// Close closes the Tunnel and releases its resources.
func (t *Tunnel) Close() {
t.procCancel()
}

func (t *Tunnel) Dialer() proxy.Dialer {
t.dialerMu.RLock()
d := t.dialer
t.dialerMu.RUnlock()
return d
}

func (t *Tunnel) SetDialer(dialer proxy.Dialer) {
t.dialerMu.Lock()
t.dialer = dialer
t.dialerMu.Unlock()
}

func (t *Tunnel) SetUDPTimeout(timeout time.Duration) {
t.udpTimeout.Store(timeout)
}
26 changes: 9 additions & 17 deletions tunnel/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,11 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

// _udpSessionTimeout is the default timeout for each UDP session.
var _udpSessionTimeout = 60 * time.Second

func SetUDPTimeout(t time.Duration) {
_udpSessionTimeout = t
}

// TODO: Port Restricted NAT support.
func handleUDPConn(uc adapter.UDPConn) {
func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) {
defer uc.Close()

id := uc.ID()
Expand All @@ -34,14 +26,14 @@ func handleUDPConn(uc adapter.UDPConn) {
DstPort: id.LocalPort,
}

pc, err := proxy.DialUDP(metadata)
pc, err := t.Dialer().DialUDP(metadata)
if err != nil {
log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err)
return
}
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())

pc = statistic.DefaultUDPTracker(pc, metadata)
pc = statistic.NewUDPTracker(pc, metadata, t.manager)
defer pc.Close()

var remote net.Addr
Expand All @@ -53,22 +45,22 @@ func handleUDPConn(uc adapter.UDPConn) {
pc = newSymmetricNATPacketConn(pc, metadata)

log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
pipePacket(uc, pc, remote)
pipePacket(uc, pc, remote, t.udpTimeout.Load())
}

func pipePacket(origin, remote net.PacketConn, to net.Addr) {
func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) {
wg := sync.WaitGroup{}
wg.Add(2)

go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg)
go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg)
go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout)
go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout)

wg.Wait()
}

func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup) {
func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
if err := copyPacketData(dst, src, to, _udpSessionTimeout); err != nil {
if err := copyPacketData(dst, src, to, timeout); err != nil {
log.Debugf("[UDP] copy data for %s: %v", dir, err)
}
}
Expand Down

0 comments on commit fd98f65

Please sign in to comment.