From 193494ca82ee1d28226de5100240507e645c3a0d Mon Sep 17 00:00:00 2001 From: naison <895703375@qq.com> Date: Mon, 18 Nov 2024 10:47:12 +0000 Subject: [PATCH] refactor: refactor code --- pkg/core/tunhandler.go | 105 ++++----------------------- pkg/core/tunhandlerclient.go | 4 +- pkg/daemon/handler/ssh.go | 2 +- pkg/util/net.go | 135 +++++++++++++++++++++++------------ 4 files changed, 109 insertions(+), 137 deletions(-) diff --git a/pkg/core/tunhandler.go b/pkg/core/tunhandler.go index a8443477..c1dbaa7b 100644 --- a/pkg/core/tunhandler.go +++ b/pkg/core/tunhandler.go @@ -2,15 +2,10 @@ package core import ( "context" - "fmt" - "math" - "math/rand" "net" "sync" "time" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" log "github.com/sirupsen/logrus" "github.com/wencaiwulue/kubevpn/v2/pkg/config" @@ -62,14 +57,6 @@ func (n *RouteMap) RouteTo(ip net.IP) net.Addr { return n.routes[ip.String()] } -func (n *RouteMap) Range(f func(key string, value net.Addr)) { - n.lock.RLock() - defer n.lock.RUnlock() - for k, v := range n.routes { - f(k, v) - } -} - // TunHandler creates a handler for tun tunnel. func TunHandler(chain *Chain, node *Node) Handler { return &tunHandler{ @@ -89,19 +76,6 @@ func (h *tunHandler) Handle(ctx context.Context, tun net.Conn) { } } -func (h *tunHandler) printRoute(ctx context.Context) { - ticker := time.NewTicker(time.Second * 5) - defer ticker.Stop() - for ctx.Err() == nil { - select { - case <-ticker.C: - h.routeMapUDP.Range(func(key string, value net.Addr) { - log.Debugf("To: %s, route: %s", key, value.String()) - }) - } - } -} - type Device struct { tun net.Conn @@ -163,12 +137,12 @@ func (d *Device) Close() { } func heartbeats(ctx context.Context, tun net.Conn) { - conn, err := util.GetTunDeviceByConn(tun) + tunIfi, err := util.GetTunDeviceByConn(tun) if err != nil { log.Errorf("Failed to get tun device: %s", err.Error()) return } - srcIPv4, srcIPv6, err := util.GetLocalTunIP(conn.Name) + srcIPv4, srcIPv6, err := util.GetTunDeviceIP(tunIfi.Name) if err != nil { return } @@ -178,10 +152,17 @@ func heartbeats(ctx context.Context, tun net.Conn) { if config.RouterIP6.To4().Equal(srcIPv6) { return } + if config.DockerRouterIP.To4().Equal(srcIPv4) { + return + } var dstIPv4, dstIPv6 = net.IPv4zero, net.IPv6zero if config.CIDR.Contains(srcIPv4) { - dstIPv4, dstIPv6 = config.RouterIP, config.RouterIP6 - } else if config.DockerCIDR.Contains(srcIPv4) { + dstIPv4 = config.RouterIP + } + if config.CIDR6.Contains(srcIPv6) { + dstIPv6 = config.RouterIP6 + } + if config.DockerCIDR.Contains(srcIPv4) { dstIPv4 = config.DockerRouterIP } @@ -198,69 +179,15 @@ func heartbeats(ctx context.Context, tun net.Conn) { var src, dst net.IP src, dst = srcIPv4, dstIPv4 if !dst.IsUnspecified() { - _, _ = util.Ping(ctx, src.String(), dst.String()) + go util.Ping(ctx, src.String(), dst.String()) } src, dst = srcIPv6, dstIPv6 if !dst.IsUnspecified() { - _, _ = util.Ping(ctx, src.String(), dst.String()) + go util.Ping(ctx, src.String(), dst.String()) } } } -func genICMPPacket(src net.IP, dst net.IP) ([]byte, error) { - buf := gopacket.NewSerializeBuffer() - var id uint16 - for _, b := range src { - id += uint16(b) - } - icmpLayer := layers.ICMPv4{ - TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), - Id: id, - Seq: uint16(rand.Intn(math.MaxUint16 + 1)), - } - ipLayer := layers.IPv4{ - Version: 4, - SrcIP: src, - DstIP: dst, - Protocol: layers.IPProtocolICMPv4, - Flags: layers.IPv4DontFragment, - TTL: 64, - IHL: 5, - Id: uint16(rand.Intn(math.MaxUint16 + 1)), - } - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - err := gopacket.SerializeLayers(buf, opts, &ipLayer, &icmpLayer) - if err != nil { - return nil, fmt.Errorf("failed to serialize icmp packet, err: %v", err) - } - return buf.Bytes(), nil -} - -func genICMPPacketIPv6(src net.IP, dst net.IP) ([]byte, error) { - buf := gopacket.NewSerializeBuffer() - icmpLayer := layers.ICMPv6{ - TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0), - } - ipLayer := layers.IPv6{ - Version: 6, - SrcIP: src, - DstIP: dst, - NextHeader: layers.IPProtocolICMPv6, - HopLimit: 255, - } - opts := gopacket.SerializeOptions{ - FixLengths: true, - } - err := gopacket.SerializeLayers(buf, opts, &ipLayer, &icmpLayer) - if err != nil { - return nil, fmt.Errorf("failed to serialize icmp6 packet, err: %v", err) - } - return buf.Bytes(), nil -} - func (d *Device) Start(ctx context.Context) { go d.readFromTun() go d.tunInboundHandler(d.tunInbound, d.tunOutbound) @@ -281,8 +208,6 @@ func (d *Device) SetTunInboundHandler(handler func(tunInbound <-chan *DataElem, } func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) { - go h.printRoute(ctx) - device := &Device{ tun: tun, tunInbound: make(chan *DataElem, MaxSize), @@ -296,7 +221,7 @@ func (h *tunHandler) HandleServer(ctx context.Context, tun net.Conn) { log.Errorf("[UDP] Failed to listen %s: %v", h.node.Addr, err) return } - err = transportTun(ctx, tunInbound, tunOutbound, packetConn, h.routeMapUDP, h.routeMapTCP) + err = transportTunServer(ctx, tunInbound, tunOutbound, packetConn, h.routeMapUDP, h.routeMapTCP) if err != nil { log.Errorf("[TUN] %s: %v", tun.LocalAddr(), err) } @@ -480,7 +405,7 @@ func (p *Peer) Close() { p.conn.Close() } -func transportTun(ctx context.Context, tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem, packetConn net.PacketConn, routeMapUDP *RouteMap, routeMapTCP *sync.Map) error { +func transportTunServer(ctx context.Context, tunInbound <-chan *DataElem, tunOutbound chan<- *DataElem, packetConn net.PacketConn, routeMapUDP *RouteMap, routeMapTCP *sync.Map) error { p := &Peer{ conn: packetConn, connInbound: make(chan *udpElem, MaxSize), diff --git a/pkg/core/tunhandlerclient.go b/pkg/core/tunhandlerclient.go index 2e98b300..de8dc228 100644 --- a/pkg/core/tunhandlerclient.go +++ b/pkg/core/tunhandlerclient.go @@ -89,7 +89,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut _, err := packetConn.WriteTo(e.data[:e.length], remoteAddr) config.LPool.Put(e.data[:]) if err != nil { - errChan <- errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr)) + util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to write packet to remote %s", remoteAddr))) return } } @@ -100,7 +100,7 @@ func transportTunClient(ctx context.Context, tunInbound <-chan *DataElem, tunOut b := config.LPool.Get().([]byte)[:] n, _, err := packetConn.ReadFrom(b[:]) if err != nil { - errChan <- errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr)) + util.SafeWrite(errChan, errors.Wrap(err, fmt.Sprintf("failed to read packet from remote %s", remoteAddr))) return } util.SafeWrite(tunOutbound, &DataElem{data: b[:], length: n}) diff --git a/pkg/daemon/handler/ssh.go b/pkg/daemon/handler/ssh.go index c688b0fa..63dc9c76 100644 --- a/pkg/daemon/handler/ssh.go +++ b/pkg/daemon/handler/ssh.go @@ -126,7 +126,7 @@ func (w *wsHandler) handle(c context.Context) { log.Info("Connected tunnel") go func() { for ctx.Err() == nil { - _, _ = util.Ping(ctx, clientIP.IP.String(), ip.String()) + util.Ping(ctx, clientIP.IP.String(), ip.String()) time.Sleep(time.Second * 5) } }() diff --git a/pkg/util/net.go b/pkg/util/net.go index 5039cd3d..f2bf6dd0 100644 --- a/pkg/util/net.go +++ b/pkg/util/net.go @@ -4,12 +4,15 @@ import ( "context" "errors" "fmt" + "math" + "math/rand" "net" - "strings" "time" "github.com/cilium/ipam/service/allocator" "github.com/cilium/ipam/service/ipallocator" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/prometheus-community/pro-bing" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -23,14 +26,16 @@ func GetTunDevice(ips ...net.IP) (*net.Interface, error) { return nil, err } for _, i := range interfaces { - addrs, err := i.Addrs() + addrList, err := i.Addrs() if err != nil { return nil, err } - for _, addr := range addrs { - for _, ip := range ips { - if strings.Contains(addr.String(), ip.String()) { - return &i, nil + for _, addr := range addrList { + if ipNet, ok := addr.(*net.IPNet); ok { + for _, ip := range ips { + if ipNet.IP.Equal(ip) { + return &i, nil + } } } } @@ -39,53 +44,39 @@ func GetTunDevice(ips ...net.IP) (*net.Interface, error) { } func GetTunDeviceByConn(tun net.Conn) (*net.Interface, error) { - interfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - var ip string - if tunIP, ok := tun.LocalAddr().(*net.IPNet); ok { - ip = tunIP.IP.String() - } else { - ip = tun.LocalAddr().String() - } - for _, i := range interfaces { - addrs, err := i.Addrs() - if err != nil { - return nil, err - } - for _, addr := range addrs { - if strings.Contains(addr.String(), tun.LocalAddr().String()) { - return &i, nil - } - } + var ip net.IP + switch tun.LocalAddr().(type) { + case *net.IPNet: + ip = tun.LocalAddr().(*net.IPNet).IP + case *net.IPAddr: + ip = tun.LocalAddr().(*net.IPAddr).IP } - return nil, fmt.Errorf("can not found any interface with ip %v", ip) + return GetTunDevice(ip) } -func GetLocalTunIP(tunName string) (net.IP, net.IP, error) { - tunIface, err := net.InterfaceByName(tunName) +func GetTunDeviceIP(tunName string) (net.IP, net.IP, error) { + tunIfi, err := net.InterfaceByName(tunName) if err != nil { return nil, nil, err } - addrs, err := tunIface.Addrs() + addrList, err := tunIfi.Addrs() if err != nil { return nil, nil, err } var srcIPv4, srcIPv6 net.IP - for _, addr := range addrs { - ip, _, err := net.ParseCIDR(addr.String()) - if err != nil { - continue - } - if ip.To4() != nil { - srcIPv4 = ip - } else { - srcIPv6 = ip + for _, addr := range addrList { + if ipNet, ok := addr.(*net.IPNet); ok { + if config.CIDR.Contains(ipNet.IP) || config.CIDR6.Contains(ipNet.IP) || config.DockerCIDR.Contains(ipNet.IP) { + if ipNet.IP.To4() != nil { + srcIPv4 = ipNet.IP + } else { + srcIPv6 = ipNet.IP + } + } } } - if srcIPv4 == nil || srcIPv6 == nil { - return srcIPv4, srcIPv6, fmt.Errorf("not found all ip") + if srcIPv4 == nil && srcIPv6 == nil { + return srcIPv4, srcIPv6, fmt.Errorf("can not found any ip") } return srcIPv4, srcIPv6, nil } @@ -99,7 +90,8 @@ func PingOnce(ctx context.Context, srcIP, dstIP string) (bool, error) { pinger.SetLogger(nil) pinger.SetPrivileged(true) pinger.Count = 1 - pinger.Timeout = time.Millisecond * 1000 + pinger.Timeout = time.Second * 1 + pinger.ResolveTimeout = time.Second * 1 err = pinger.RunWithContext(ctx) // Blocks until finished. if err != nil { return false, err @@ -116,8 +108,9 @@ func Ping(ctx context.Context, srcIP, dstIP string) (bool, error) { pinger.Source = srcIP pinger.SetLogger(nil) pinger.SetPrivileged(true) - pinger.Count = 3 - pinger.Timeout = time.Millisecond * 1500 + pinger.Count = 4 + pinger.Timeout = time.Second * 4 + pinger.ResolveTimeout = time.Second * 1 err = pinger.RunWithContext(ctx) // Blocks until finished. if err != nil { return false, err @@ -193,3 +186,57 @@ func getIP(addr net.Addr) net.IP { } return ip } + +func GenICMPPacket(src net.IP, dst net.IP) ([]byte, error) { + buf := gopacket.NewSerializeBuffer() + var id uint16 + for _, b := range src { + id += uint16(b) + } + icmpLayer := layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + Id: id, + Seq: uint16(rand.Intn(math.MaxUint16 + 1)), + } + ipLayer := layers.IPv4{ + Version: 4, + SrcIP: src, + DstIP: dst, + Protocol: layers.IPProtocolICMPv4, + Flags: layers.IPv4DontFragment, + TTL: 64, + IHL: 5, + Id: uint16(rand.Intn(math.MaxUint16 + 1)), + } + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + err := gopacket.SerializeLayers(buf, opts, &ipLayer, &icmpLayer) + if err != nil { + return nil, fmt.Errorf("failed to serialize icmp packet, err: %v", err) + } + return buf.Bytes(), nil +} + +func GenICMPPacketIPv6(src net.IP, dst net.IP) ([]byte, error) { + buf := gopacket.NewSerializeBuffer() + icmpLayer := layers.ICMPv6{ + TypeCode: layers.CreateICMPv6TypeCode(layers.ICMPv6TypeEchoRequest, 0), + } + ipLayer := layers.IPv6{ + Version: 6, + SrcIP: src, + DstIP: dst, + NextHeader: layers.IPProtocolICMPv6, + HopLimit: 255, + } + opts := gopacket.SerializeOptions{ + FixLengths: true, + } + err := gopacket.SerializeLayers(buf, opts, &ipLayer, &icmpLayer) + if err != nil { + return nil, fmt.Errorf("failed to serialize icmp6 packet, err: %v", err) + } + return buf.Bytes(), nil +}