From f27b7a7d308816e01dd9f237e0bcfae4824cf3ad Mon Sep 17 00:00:00 2001 From: Murtaza Aliakbar Date: Sat, 24 Aug 2024 03:18:05 +0530 Subject: [PATCH] udpmux: assoc dst network; netip as route key --- intra/udp.go | 10 +--- intra/udpmux.go | 144 ++++++++++++++++++++++++++---------------------- 2 files changed, 81 insertions(+), 73 deletions(-) diff --git a/intra/udp.go b/intra/udp.go index ed783e6c..fa8dfc7f 100644 --- a/intra/udp.go +++ b/intra/udp.go @@ -322,13 +322,9 @@ func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPor // note: fake-dns-ips shouldn't be un-nated / un-alg'd for i, dstipp := range makeIPPorts(realips, target, 0) { selectedTarget = dstipp - if mux { // pc is net.PacketConn (which confirms to core.UDPConn) - var mxr *muxer - // mux not supported by all proxies (few like Exit, Base, WG support it) - if mxr, err = h.mux.associate(cid, src, px.Dialer().Announce, dmx); err == nil { - pc, err = mxr.vend(net.UDPAddrFromAddrPort(selectedTarget)) - } - } else { // pc is net.Conn (which confirms to core.UDPConn) + if mux { // mux is not supported by all proxies (few like Exit, Base, WG support it) + pc, err = h.mux.associate(cid, src, selectedTarget, px.Dialer().Announce, dmx) + } else { pc, err = px.Dialer().Dial("udp", selectedTarget.String()) } if err == nil { diff --git a/intra/udpmux.go b/intra/udpmux.go index f63adf20..00918d0a 100644 --- a/intra/udpmux.go +++ b/intra/udpmux.go @@ -27,10 +27,6 @@ const ( maxtimeouterrors = 3 ) -var ( - errMuxerDone = errors.New("udp: muxer closed") -) - type sender interface { id() string sendto([]byte, net.Addr) (int, error) @@ -65,17 +61,18 @@ type muxer struct { cb func() // muxer.stop() callback (in a new goroutine) vnd netstack.DemuxerFn // for new routes in netstack - rmu sync.Mutex // protects routes - routes map[string]*demuxconn // remote addr -> demuxed conn + rmu sync.Mutex // protects routes + routes map[netip.AddrPort]*demuxconn // remote addr -> demuxed conn dxconnWG *sync.WaitGroup // wait group for demuxed conns } // demuxconn writes to addr and reads from the muxer type demuxconn struct { - remux sender // promiscuous sender - raddr net.Addr // remote address connected to - laddr net.Addr // local address connected from + remux sender // promiscuous sender + key netip.AddrPort // promiscuous factor (same as raddr) + raddr net.Addr // remote address connected to + laddr net.Addr // local address connected from incomingCh chan *slice // incoming data, never closed overflowCh chan *slice // overflow data, never closed @@ -104,7 +101,7 @@ func newMuxerLocked(id string, conn net.PacketConn, vnd netstack.DemuxerFn, f fu cid: id, mxconn: conn, stats: &stats{start: time.Now()}, - routes: make(map[string]*demuxconn), + routes: make(map[netip.AddrPort]*demuxconn), rmu: sync.Mutex{}, dxconns: make(chan *demuxconn), doneCh: make(chan struct{}), @@ -206,60 +203,60 @@ func (x *muxer) readers() { log.W("udp: mux: %s read done n(%d): nil remote addr; skip", x.cid, n) continue } - - if dst, err := x.route(who); err != nil { - // route fails if muxer.dxconns is closed (which is never closed) - log.W("udp: mux: %s new route err: %v", x.cid, err) - return - } else { // may be existing route or a new route + // may be existing route or a new route + if dst := x.route(addr2netip(who)); dst != nil { select { case dst.incomingCh <- &slice{v: b[:n], free: free}: // incomingCh is never closed default: // dst probably closed, but not yet unrouted log.W("udp: mux: %s read: drop(sz: %d); route to %s", x.cid, n, dst.raddr) } - } + } // else: ignore (who is invalid or x is closed) } } -func (x *muxer) route(raddr net.Addr) (*demuxconn, error) { +func (x *muxer) route(to netip.AddrPort) *demuxconn { x.rmu.Lock() defer x.rmu.Unlock() - addr := raddr.String() // raddr must never be nil - conn, ok := x.routes[addr] - if !ok || conn == nil { + if !to.IsValid() { + log.W("udp: mux: %s route: invalid addr %s", x.cid, to) + return nil + } + + conn := x.routes[to] + if conn == nil { // new routes created here won't really exist in netstack if // settings.EndpointIndependentMapping or settings.EndpointIndependentFiltering // is set to false. - conn = x.newLocked(raddr) + conn = x.newLocked(to) select { case <-x.doneCh: clos(conn) - return nil, errMuxerDone + log.W("udp: mux: %s route: for %s; muxer closed", x.cid, to) + return nil case x.dxconns <- conn: x.stats.dxcount.Add(1) - x.routes[addr] = conn - if dst, err := addr2netip(raddr); err == nil && dst.IsValid() { - go x.vnd(dst) - } else { // should never happen - log.E("udp: mux: %s route: invalid addr %s; err: %v", x.cid, raddr, err) - } - log.I("udp: mux: %s route: new for %s; stats: %d", - x.cid, raddr, x.stats) + x.routes[to] = conn + core.Go("udpmux.vend", func() { + verr := x.vnd(to) // a fork in the road + if verr != nil { + clos(conn) + log.E("udp: mux: %s route: vend failure %s; err %v", x.cid, to, verr) + } + }) + log.I("udp: mux: %s route: new for %s; stats: %d", x.cid, to, x.stats) } } - return conn, nil + return conn } -func (x *muxer) unroute(cc ...*demuxconn) { +func (x *muxer) unroute(c *demuxconn) { // don't really expect to handle panic w/ core.Recover x.rmu.Lock() defer x.rmu.Unlock() - for _, c := range cc { - log.I("udp: mux: %s unrouting... %s => %s", x.cid, c.laddr, c.raddr) - delete(x.routes, c.raddr.String()) - } + log.I("udp: mux: %s unrouting... %s => %s", x.cid, c.laddr, c.raddr) + delete(x.routes, c.key) } func (x *muxer) id() string { return x.cid } @@ -285,14 +282,15 @@ func (x *muxer) extend(t time.Time) { } // new creates a demuxed conn to r. -func (x *muxer) newLocked(r net.Addr) *demuxconn { +func (x *muxer) newLocked(r netip.AddrPort) *demuxconn { return &demuxconn{ - remux: x, // muxer - laddr: x.mxconn.LocalAddr(), // listen addr - raddr: r, // sendto addr - incomingCh: make(chan *slice, 32), // read from muxer - overflowCh: make(chan *slice, 16), // overflow from read - closed: make(chan struct{}), // always unbuffered + remux: x, // muxer + laddr: x.mxconn.LocalAddr(), // listen addr + raddr: net.UDPAddrFromAddrPort(r), // remote addr + key: r, // key (same as raddr) + incomingCh: make(chan *slice, 32), // read from muxer + overflowCh: make(chan *slice, 16), // overflow from read + closed: make(chan struct{}), // always unbuffered wt: time.NewTicker(udptimeout), rt: time.NewTicker(udptimeout), wto: udptimeout, @@ -301,10 +299,11 @@ func (x *muxer) newLocked(r net.Addr) *demuxconn { } // TODO: make sure a conn can only be vend once -func (x *muxer) vend(dst net.Addr) (net.Conn, error) { - c, err := x.route(dst) - if err != nil { - return nil, err +func (x *muxer) vend(dst netip.AddrPort) (net.Conn, error) { + c := x.route(dst) + if c == nil { + log.E("udp: mux: %s vend: no conn for %s", x.cid, dst) + return nil, errUdpSetupConn } return c, nil } @@ -337,17 +336,18 @@ func (c *demuxconn) Write(p []byte) (n int, err error) { } } -// ReadFrom implements core.UDPConn.ReadFrom +// ReadFrom implements core.UDPConn.ReadFrom (unused) func (c *demuxconn) ReadFrom(p []byte) (int, net.Addr, error) { n, err := c.Read(p) return n, c.raddr, err } -// WriteTo implements core.UDPConn.WriteTo +// WriteTo implements core.UDPConn.WriteTo (unused) func (c *demuxconn) WriteTo(p []byte, to net.Addr) (int, error) { - if to.String() != c.raddr.String() { - return 0, net.ErrWriteToConnected - } + // todo: check if "to" is the same as c.raddr + // if to != c.raddr { + // return 0, net.ErrWriteToConnected + // } return c.Write(p) } @@ -371,7 +371,6 @@ func (c *demuxconn) Close() error { } } }) - return nil } @@ -457,24 +456,29 @@ func newMuxTable() *muxTable { return &muxTable{t: make(map[netip.AddrPort]*muxer)} } -func (e *muxTable) associate(id string, src netip.AddrPort, mk assocFn, v netstack.DemuxerFn) (*muxer, error) { +func (e *muxTable) associate(id string, src, dst netip.AddrPort, mk assocFn, v netstack.DemuxerFn) (c net.Conn, err error) { e.Lock() defer e.Unlock() - proto, anyaddr := anyaddrFor(src) - if mxr, ok := e.t[src]; ok { - return mxr, nil - } else if pc, err := mk(proto, anyaddr); err == nil { - log.I("udp: mux: %s new assoc for %s", id, src) + var mxr *muxer + // dst may be of a different family than src (4to6, 6to4 etc) + // and so, rely on dst to determine the family to listen on. + proto, anyaddr := anyaddrFor(dst) + mxr = e.t[src] + if mxr == nil { + var pc net.PacketConn + pc, err = mk(proto, anyaddr) + if err != nil { + core.Close(pc) + return nil, err + } mxr = newMuxerLocked(id, pc, v, func() { e.dissociate(id, src) }) e.t[src] = mxr - return mxr, nil - } else { - core.Close(pc) - return nil, err + log.I("udp: mux: %s new assoc for %s", id, src) } + return mxr.vend(dst) } func (e *muxTable) dissociate(id string, src netip.AddrPort) { @@ -485,6 +489,14 @@ func (e *muxTable) dissociate(id string, src netip.AddrPort) { delete(e.t, src) } -func addr2netip(addr net.Addr) (netip.AddrPort, error) { - return netip.ParseAddrPort(addr.String()) +func addr2netip(addr net.Addr) (zz netip.AddrPort) { + if addr == nil { + return // zz + } + ipp, err := netip.ParseAddrPort(addr.String()) + if err == nil { + log.W("udp: mux: addr2netip: %v", err) + return // zz + } + return ipp // may be invalid }