Skip to content

Commit

Permalink
udpmux: assoc dst network; netip as route key
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Aug 23, 2024
1 parent e4adcc1 commit f27b7a7
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 73 deletions.
10 changes: 3 additions & 7 deletions intra/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,9 @@ func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPor
// note: fake-dns-ips shouldn't be un-nated / un-alg'd
for i, dstipp := range makeIPPorts(realips, target, 0) {
selectedTarget = dstipp
if mux { // pc is net.PacketConn (which confirms to core.UDPConn)
var mxr *muxer
// mux not supported by all proxies (few like Exit, Base, WG support it)
if mxr, err = h.mux.associate(cid, src, px.Dialer().Announce, dmx); err == nil {
pc, err = mxr.vend(net.UDPAddrFromAddrPort(selectedTarget))
}
} else { // pc is net.Conn (which confirms to core.UDPConn)
if mux { // mux is not supported by all proxies (few like Exit, Base, WG support it)
pc, err = h.mux.associate(cid, src, selectedTarget, px.Dialer().Announce, dmx)
} else {
pc, err = px.Dialer().Dial("udp", selectedTarget.String())
}
if err == nil {
Expand Down
144 changes: 78 additions & 66 deletions intra/udpmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ const (
maxtimeouterrors = 3
)

var (
errMuxerDone = errors.New("udp: muxer closed")
)

type sender interface {
id() string
sendto([]byte, net.Addr) (int, error)
Expand Down Expand Up @@ -65,17 +61,18 @@ type muxer struct {
cb func() // muxer.stop() callback (in a new goroutine)
vnd netstack.DemuxerFn // for new routes in netstack

rmu sync.Mutex // protects routes
routes map[string]*demuxconn // remote addr -> demuxed conn
rmu sync.Mutex // protects routes
routes map[netip.AddrPort]*demuxconn // remote addr -> demuxed conn

dxconnWG *sync.WaitGroup // wait group for demuxed conns
}

// demuxconn writes to addr and reads from the muxer
type demuxconn struct {
remux sender // promiscuous sender
raddr net.Addr // remote address connected to
laddr net.Addr // local address connected from
remux sender // promiscuous sender
key netip.AddrPort // promiscuous factor (same as raddr)
raddr net.Addr // remote address connected to
laddr net.Addr // local address connected from

incomingCh chan *slice // incoming data, never closed
overflowCh chan *slice // overflow data, never closed
Expand Down Expand Up @@ -104,7 +101,7 @@ func newMuxerLocked(id string, conn net.PacketConn, vnd netstack.DemuxerFn, f fu
cid: id,
mxconn: conn,
stats: &stats{start: time.Now()},
routes: make(map[string]*demuxconn),
routes: make(map[netip.AddrPort]*demuxconn),
rmu: sync.Mutex{},
dxconns: make(chan *demuxconn),
doneCh: make(chan struct{}),
Expand Down Expand Up @@ -206,60 +203,60 @@ func (x *muxer) readers() {
log.W("udp: mux: %s read done n(%d): nil remote addr; skip", x.cid, n)
continue
}

if dst, err := x.route(who); err != nil {
// route fails if muxer.dxconns is closed (which is never closed)
log.W("udp: mux: %s new route err: %v", x.cid, err)
return
} else { // may be existing route or a new route
// may be existing route or a new route
if dst := x.route(addr2netip(who)); dst != nil {
select {
case dst.incomingCh <- &slice{v: b[:n], free: free}: // incomingCh is never closed
default: // dst probably closed, but not yet unrouted
log.W("udp: mux: %s read: drop(sz: %d); route to %s", x.cid, n, dst.raddr)
}
}
} // else: ignore (who is invalid or x is closed)
}
}

func (x *muxer) route(raddr net.Addr) (*demuxconn, error) {
func (x *muxer) route(to netip.AddrPort) *demuxconn {
x.rmu.Lock()
defer x.rmu.Unlock()

addr := raddr.String() // raddr must never be nil
conn, ok := x.routes[addr]
if !ok || conn == nil {
if !to.IsValid() {
log.W("udp: mux: %s route: invalid addr %s", x.cid, to)
return nil
}

conn := x.routes[to]
if conn == nil {
// new routes created here won't really exist in netstack if
// settings.EndpointIndependentMapping or settings.EndpointIndependentFiltering
// is set to false.
conn = x.newLocked(raddr)
conn = x.newLocked(to)
select {
case <-x.doneCh:
clos(conn)
return nil, errMuxerDone
log.W("udp: mux: %s route: for %s; muxer closed", x.cid, to)
return nil
case x.dxconns <- conn:
x.stats.dxcount.Add(1)
x.routes[addr] = conn
if dst, err := addr2netip(raddr); err == nil && dst.IsValid() {
go x.vnd(dst)
} else { // should never happen
log.E("udp: mux: %s route: invalid addr %s; err: %v", x.cid, raddr, err)
}
log.I("udp: mux: %s route: new for %s; stats: %d",
x.cid, raddr, x.stats)
x.routes[to] = conn
core.Go("udpmux.vend", func() {
verr := x.vnd(to) // a fork in the road
if verr != nil {
clos(conn)
log.E("udp: mux: %s route: vend failure %s; err %v", x.cid, to, verr)
}
})
log.I("udp: mux: %s route: new for %s; stats: %d", x.cid, to, x.stats)
}
}
return conn, nil
return conn
}

func (x *muxer) unroute(cc ...*demuxconn) {
func (x *muxer) unroute(c *demuxconn) {
// don't really expect to handle panic w/ core.Recover
x.rmu.Lock()
defer x.rmu.Unlock()

for _, c := range cc {
log.I("udp: mux: %s unrouting... %s => %s", x.cid, c.laddr, c.raddr)
delete(x.routes, c.raddr.String())
}
log.I("udp: mux: %s unrouting... %s => %s", x.cid, c.laddr, c.raddr)
delete(x.routes, c.key)
}

func (x *muxer) id() string { return x.cid }
Expand All @@ -285,14 +282,15 @@ func (x *muxer) extend(t time.Time) {
}

// new creates a demuxed conn to r.
func (x *muxer) newLocked(r net.Addr) *demuxconn {
func (x *muxer) newLocked(r netip.AddrPort) *demuxconn {
return &demuxconn{
remux: x, // muxer
laddr: x.mxconn.LocalAddr(), // listen addr
raddr: r, // sendto addr
incomingCh: make(chan *slice, 32), // read from muxer
overflowCh: make(chan *slice, 16), // overflow from read
closed: make(chan struct{}), // always unbuffered
remux: x, // muxer
laddr: x.mxconn.LocalAddr(), // listen addr
raddr: net.UDPAddrFromAddrPort(r), // remote addr
key: r, // key (same as raddr)
incomingCh: make(chan *slice, 32), // read from muxer
overflowCh: make(chan *slice, 16), // overflow from read
closed: make(chan struct{}), // always unbuffered
wt: time.NewTicker(udptimeout),
rt: time.NewTicker(udptimeout),
wto: udptimeout,
Expand All @@ -301,10 +299,11 @@ func (x *muxer) newLocked(r net.Addr) *demuxconn {
}

// TODO: make sure a conn can only be vend once
func (x *muxer) vend(dst net.Addr) (net.Conn, error) {
c, err := x.route(dst)
if err != nil {
return nil, err
func (x *muxer) vend(dst netip.AddrPort) (net.Conn, error) {
c := x.route(dst)
if c == nil {
log.E("udp: mux: %s vend: no conn for %s", x.cid, dst)
return nil, errUdpSetupConn
}
return c, nil
}
Expand Down Expand Up @@ -337,17 +336,18 @@ func (c *demuxconn) Write(p []byte) (n int, err error) {
}
}

// ReadFrom implements core.UDPConn.ReadFrom
// ReadFrom implements core.UDPConn.ReadFrom (unused)
func (c *demuxconn) ReadFrom(p []byte) (int, net.Addr, error) {
n, err := c.Read(p)
return n, c.raddr, err
}

// WriteTo implements core.UDPConn.WriteTo
// WriteTo implements core.UDPConn.WriteTo (unused)
func (c *demuxconn) WriteTo(p []byte, to net.Addr) (int, error) {
if to.String() != c.raddr.String() {
return 0, net.ErrWriteToConnected
}
// todo: check if "to" is the same as c.raddr
// if to != c.raddr {
// return 0, net.ErrWriteToConnected
// }
return c.Write(p)
}

Expand All @@ -371,7 +371,6 @@ func (c *demuxconn) Close() error {
}
}
})

return nil
}

Expand Down Expand Up @@ -457,24 +456,29 @@ func newMuxTable() *muxTable {
return &muxTable{t: make(map[netip.AddrPort]*muxer)}
}

func (e *muxTable) associate(id string, src netip.AddrPort, mk assocFn, v netstack.DemuxerFn) (*muxer, error) {
func (e *muxTable) associate(id string, src, dst netip.AddrPort, mk assocFn, v netstack.DemuxerFn) (c net.Conn, err error) {
e.Lock()
defer e.Unlock()

proto, anyaddr := anyaddrFor(src)
if mxr, ok := e.t[src]; ok {
return mxr, nil
} else if pc, err := mk(proto, anyaddr); err == nil {
log.I("udp: mux: %s new assoc for %s", id, src)
var mxr *muxer
// dst may be of a different family than src (4to6, 6to4 etc)
// and so, rely on dst to determine the family to listen on.
proto, anyaddr := anyaddrFor(dst)
mxr = e.t[src]
if mxr == nil {
var pc net.PacketConn
pc, err = mk(proto, anyaddr)
if err != nil {
core.Close(pc)
return nil, err
}
mxr = newMuxerLocked(id, pc, v, func() {
e.dissociate(id, src)
})
e.t[src] = mxr
return mxr, nil
} else {
core.Close(pc)
return nil, err
log.I("udp: mux: %s new assoc for %s", id, src)
}
return mxr.vend(dst)
}

func (e *muxTable) dissociate(id string, src netip.AddrPort) {
Expand All @@ -485,6 +489,14 @@ func (e *muxTable) dissociate(id string, src netip.AddrPort) {
delete(e.t, src)
}

func addr2netip(addr net.Addr) (netip.AddrPort, error) {
return netip.ParseAddrPort(addr.String())
func addr2netip(addr net.Addr) (zz netip.AddrPort) {
if addr == nil {
return // zz
}
ipp, err := netip.ParseAddrPort(addr.String())
if err == nil {
log.W("udp: mux: addr2netip: %v", err)
return // zz
}
return ipp // may be invalid
}

1 comment on commit f27b7a7

@ignoramous
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#77

Please sign in to comment.