diff --git a/intra/netstack/udp.go b/intra/netstack/udp.go index 2c93d4d3..a0e66945 100644 --- a/intra/netstack/udp.go +++ b/intra/netstack/udp.go @@ -23,13 +23,20 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var errMissingEp = errors.New("not connected to any endpoint") +var ( + errMissingEp = errors.New("not connected to any endpoint") + errMissingReq = errors.New("missing forwarder request") + errFilteredOut = errors.New("no eif; filtered out") +) + +type DemuxerFn func(dst netip.AddrPort) error type GUDPConnHandler interface { // Proxy proxies data between conn (src) and dst. Proxy(conn *GUDPConn, src, dst netip.AddrPort) bool - // ProxyMux proxies data between conn and multiple destinations. - ProxyMux(conn *GUDPConn, src, dst netip.AddrPort) bool + // ProxyMux proxies data between conn and multiple destinations + // (endpoint-independent mapping). + ProxyMux(conn *GUDPConn, src, dst netip.AddrPort, dmx DemuxerFn) bool // Error notes the error in connecting src to dst. Error(conn *GUDPConn, src, dst netip.AddrPort, err error) // CloseConns closes conns by ids, or all if ids is empty. @@ -42,10 +49,17 @@ var _ core.UDPConn = (*GUDPConn)(nil) type GUDPConn struct { stack *stack.Stack - c *core.Volatile[*gonet.UDPConn] // conn exposes UDP semantics atop endpoint - src netip.AddrPort // local addr (remote addr in netstack) - dst netip.AddrPort // remote addr (local addr in netstack) - req *udp.ForwarderRequest // egress request as UDP + + // conn exposes UDP semantics atop endpoint + c *core.Volatile[*gonet.UDPConn] + // local addr (remote addr in netstack) + // ex: 10.111.222.1:20716; same as endpoint.GetRemoteAddress + src netip.AddrPort + // remote addr (local addr in netstack) + // ex: 10.111.222.3:53; same as endpoint.GetLocalAddress + dst netip.AddrPort + + req *udp.ForwarderRequest // egress request as UDP eim bool // endpoint is muxed eif bool // endpoint is transparent @@ -85,6 +99,21 @@ func udpForwarder(s *stack.Stack, h GUDPConnHandler) *udp.Forwarder { log.E("ns: udp: forwarder: nil request") return } + + // owner app tun ns h + // repr socket packet endpoint socket + // type udp fd gudpconn core.minconn + // + // (src, dst) :1111, :53 :1111, :53 :53, :1111 :9999, :53 + // + // write :1111 => :53 :1111, :53 :53 => :1111 :9999 => :53 + // \ / + // \ / + // (pipe) \ / + // / \ + // / \ + // / \ + // read :1111 <= :53 :1111, :53 :53 <= :1111 :9999 <= :53 id := req.ID() // src 10.111.222.1:20716; same as endpoint.GetRemoteAddress src := remoteAddrPort(id) @@ -105,10 +134,30 @@ func udpForwarder(s *stack.Stack, h GUDPConnHandler) *udp.Forwarder { } } + demux := func(newdst netip.AddrPort) error { + if newdst == dst { + log.D("ns: udp: demuxer: no-op; src(%v) same as dst(%v)", src, newdst) + return nil + } + if !gc.eif { + return errFilteredOut + } + newgc := makeGUDPConn(s, nil /*not a forwarder req*/, src, newdst) + if !settings.SingleThreaded.Load() { + if err := newgc.Establish(); err != nil { + log.E("ns: udp: demuxer: dial: %v; src(%v) dst(%v)", err, src, newdst) + go h.Error(newgc, src, newdst, err) + return err + } + } + go h.Proxy(newgc, src, newdst) + return nil + } + // proxy in a separate gorountine; return immediately // why? netstack/dispatcher.go:newReadvDispatcher if gc.eim { - go h.ProxyMux(gc, src, dst) + go h.ProxyMux(gc, src, dst, demux) } else { go h.Proxy(gc, src, dst) } @@ -124,47 +173,35 @@ func (g *GUDPConn) conn() *gonet.UDPConn { } func (g *GUDPConn) StatefulTeardown() (fin bool) { - _ = g.tryConnect() // establish circuit then teardown - _ = g.Close() // then shutdown - return true // always fin + _ = g.Establish() // establish circuit then teardown + _ = g.Close() // then shutdown + return true // always fin } func (g *GUDPConn) Establish() error { - if g.eif { - return g.tryBind() - } - return g.tryConnect() -} - -func (g *GUDPConn) tryConnect() error { - if g.ok() { // already setup - return nil - } - - wq := new(waiter.Queue) - if endpoint, err := g.req.CreateEndpoint(wq); err != nil { - // ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host) - log.E("ns: udp: connect: endpoint for %v => %v; err(%v)", g.src, g.dst, err) - return e(err) - } else { - g.c.Store(gonet.NewUDPConn(wq, endpoint)) - } - return nil -} - -func (g *GUDPConn) tryBind() error { if g.ok() { // already setup return nil } - src, proto := addrport2nsaddr(g.src) - // unconnected socket w/ gonet.DialUDP - if conn, err := gonet.DialUDP(g.stack, &src, nil, proto); err != nil { - log.E("ns: udp: bind: endpoint for %v [=> %v]; err(%v)", g.src, g.dst, err) - return err + if g.req == nil { + src, proto := addrport2nsaddr(g.dst) // remote addr is local addr in netstack + dst, _ := addrport2nsaddr(g.src) // local addr is remote addr in netstack + // ingress socket w/ gonet.DialUDP + if conn, err := gonet.DialUDP(g.stack, &src, &dst, proto); err != nil { + log.E("ns: udp: dial: endpoint for %v => %v; err(%v)", g.src, g.dst, err) + return err + } else { + g.c.Store(conn) + } } else { - // todo: handle the first pkt like in g.req.CreateEndpoint - g.c.Store(conn) + wq := new(waiter.Queue) + if endpoint, err := g.req.CreateEndpoint(wq); err != nil { + // ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host) + log.E("ns: udp: connect: endpoint for %v => %v; err(%v)", g.src, g.dst, err) + return e(err) + } else { + g.c.Store(gonet.NewUDPConn(wq, endpoint)) + } } return nil } @@ -196,7 +233,14 @@ func (g *GUDPConn) Write(data []byte) (int, error) { // ep(state 3 / info &{2048 17 {53 10.111.222.3 17711 10.111.222.1} 1 10.111.222.3 1} / stats &{{{1}} {{0}} {{{0}} {{0}} {{0}} {{0}}} {{{0}} {{0}} {{0}}} {{{0}} {{0}}} {{{0}} {{0}} {{0}}}}) // 3: status:datagram-connected / {2048=>proto, 17=>transport, {53=>local-port localip 17711=>remote-port remoteip}=>endpoint-id, 1=>bind-nic-id, ip=>bind-addr, 1=>registered-nic-id} // g.ep may be nil: log.V("ns: writeFrom: from(%v) / ep(state %v / info %v / stats %v)", addr, g.ep.State(), g.ep.Info(), g.ep.Stats()) - return c.Write(data) + if g.eif { + // unexpected except in cases of DNS override; + // forward the packet to the dst as got from the first pkt + log.W("ns: udp: Write(To): unexpected; %s <= %s; sz: %d", g.src, g.dst, len(data)) + return c.WriteTo(data, net.UDPAddrFromAddrPort(g.dst)) + } else { + return c.Write(data) + } } return 0, netError(g, "udp", "write", io.ErrClosedPipe) } diff --git a/intra/tcp.go b/intra/tcp.go index 723abbad..1cb364ba 100644 --- a/intra/tcp.go +++ b/intra/tcp.go @@ -73,7 +73,7 @@ const ( ) const ( - retrytimeout = 1 * time.Minute + retrytimeout = 15 * time.Second onFlowTimeout = 5 * time.Second ) diff --git a/intra/udp.go b/intra/udp.go index e35bc610..ed783e6c 100644 --- a/intra/udp.go +++ b/intra/udp.go @@ -27,7 +27,6 @@ package intra import ( "errors" - "io" "net" "net/netip" "sync" @@ -176,32 +175,48 @@ func (h *udpHandler) onFlow(localaddr, target netip.AddrPort, realips, domains, } // ProxyMux implements netstack.GUDPConnHandler -func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) { +func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) { defer core.Recover(core.Exit11, "udp.ProxyMux") - return h.proxy(gconn, src, dst, true) + return h.proxy(gconn, src, dst, dmx) } // Error implements netstack.GUDPConnHandler. // Must be called from a goroutine. -func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, dst netip.AddrPort, err error) { - ok := h.proxy(gconn, src, dst, false) - log.I("udp: proxy: %v -> %v; err %v; recovered? %t", src, dst, err, ok) +func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, target netip.AddrPort, err error) { + log.W("udp: proxy: %v -> %v; err %v", src, target, err) + if !src.IsValid() || !target.IsValid() { + return + } + + realips, domains, probableDomains, blocklists := undoAlg(h.resolver, target.Addr()) + + // flow is alg/nat-aware, do not change target or any addrs + res := h.onFlow(src, target, realips, domains, probableDomains, blocklists) + cid, pid, uid := splitCidPidUid(res) + smm := udpSummary(cid, pid, uid, target.Addr()) + + if h.status.Load() == UDPEND { + err = errUdpEnd + } else if pid == ipn.Block { + err = errUdpFirewalled + } + smm.done(err) } // Proxy implements netstack.GUDPConnHandler; thread-safe. // Must be called from a goroutine. func (h *udpHandler) Proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) { defer core.Recover(core.Exit11, "udp.Proxy") - return h.proxy(gconn, src, dst, false) + return h.proxy(gconn, src, dst, nil) } // proxy connects src to dst over a proxy; thread-safe. -func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, mux bool) (ok bool) { - - remote, smm, ct, err := h.Connect(gconn, src, dst, mux) // remote may be nil; smm is never nil +func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) { + mux := dmx != nil + remote, smm, err := h.Connect(gconn, src, dst, dmx) // remote may be nil; smm is never nil if err != nil { - clos(gconn, remote) + core.Close(gconn, remote) queueSummary(h.smmch, h.done, smm.done(err)) // smm may be nil log.W("udp: proxy: mux? %t, unexpected %s -> %s; err: %v", mux, src, dst, err) // dst addrs no longer tracked in h.Connect: h.conntracker.Untrack(ct.CID) @@ -217,23 +232,23 @@ func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, mu cid = smm.ID } - h.conntracker.Track(ct, gconn, remote) + h.conntracker.Track(cid, gconn, remote) core.Go("udp.forward: "+cid, func() { - defer h.conntracker.Untrack(ct.CID) + defer h.conntracker.Untrack(cid) forward(gconn, &rwext{remote}, h.smmch, h.done, smm) }) return true // ok } // Connect connects the proxy server; thread-safe. -func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, mux bool) (dst core.UDPConn, smm *SocketSummary, ct core.ConnTuple, err error) { - var px ipn.Proxy = nil - var pc io.Closer = nil - - // connect gconn right away, since we assume a duplex-stream from here on - // see: h.Connect -> dnsOverride - if err = gconn.Establish(); err != nil { - log.W("udp: %s gconn connect, mux? %t, err %s => %s", src, target, mux, err) +func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, dmx netstack.DemuxerFn) (pc net.Conn, smm *SocketSummary, err error) { + mux := dmx != nil + + if !target.IsValid() { // must call h.Bind + err = errUdpUnconnected + } else { // connect gconn right away, since we assume a duplex-stream from here on + // see: h.Connect -> dnsOverride + err = gconn.Establish() } // err handled after onFlow, so that the listener knows about this gconn/flow realips, domains, probableDomains, blocklists := undoAlg(h.resolver, target.Addr()) diff --git a/intra/udpmux.go b/intra/udpmux.go index 80f68ebc..ceac108b 100644 --- a/intra/udpmux.go +++ b/intra/udpmux.go @@ -18,6 +18,7 @@ import ( "github.com/celzero/firestack/intra/core" "github.com/celzero/firestack/intra/log" + "github.com/celzero/firestack/intra/netstack" ) // from: github.com/pion/transport/blob/03c807b/udp/conn.go @@ -61,7 +62,8 @@ type muxer struct { dxconns chan *demuxconn // never closed doneCh chan struct{} // stop vending, reading, and routing once sync.Once - cb func() // muxer.stop() callback (new goroutine) + cb func() // muxer.stop() callback (in a new goroutine) + vnd netstack.DemuxerFn // for new routes in netstack rmu sync.Mutex // protects routes routes map[string]*demuxconn // remote addr -> demuxed conn @@ -249,6 +251,11 @@ func (x *muxer) route(raddr net.Addr) (*demuxconn, error) { case x.dxconns <- conn: x.stats.dxcount.Add(1) x.routes[addr] = conn + if dst, err := addr2netip(raddr); err == nil && dst.IsValid() { + go x.vnd(dst) + } else { // should never happen + log.E("udp: mux: %s route: invalid addr %s; err: %v", x.cid, raddr, err) + } log.I("udp: mux: %s route: new for %s; stats: %d", x.cid, raddr, x.stats) } @@ -488,3 +495,7 @@ func (e *muxTable) dissociate(id string, src netip.AddrPort) { defer e.Unlock() delete(e.t, src) } + +func addr2netip(addr net.Addr) (netip.AddrPort, error) { + return netip.ParseAddrPort(addr.String()) +}