diff --git a/pkg/portfwd/client.go b/pkg/portfwd/client.go index 3285fe83447..0f104fdf13a 100644 --- a/pkg/portfwd/client.go +++ b/pkg/portfwd/client.go @@ -2,20 +2,18 @@ package portfwd import ( "context" - "errors" "fmt" - "io" "net" + "time" + "github.com/containers/gvisor-tap-vsock/pkg/services/forwarder" + "github.com/lima-vm/lima/pkg/bicopy" "github.com/lima-vm/lima/pkg/guestagent/api" guestagentclient "github.com/lima-vm/lima/pkg/guestagent/api/client" "github.com/sirupsen/logrus" - "golang.org/x/sync/errgroup" ) func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) { - defer conn.Close() - id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String()) stream, err := client.Tunnel(ctx) @@ -24,26 +22,17 @@ func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgen return } - g, _ := errgroup.WithContext(ctx) - - rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr} - g.Go(func() error { - _, err := io.Copy(rw, conn) - return err - }) - g.Go(func() error { - _, err := io.Copy(conn, rw) - return err - }) - - if err := g.Wait(); err != nil { - logrus.Debugf("error in tcp tunnel for id: %s error:%v", id, err) + // Handshake message to start tunnel + if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "tcp", GuestAddr: guestAddr}); err != nil { + logrus.Errorf("could not start tcp tunnel for id: %s error:%v", id, err) + return } + + rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "tcp"} + bicopy.Bicopy(rw, conn, nil) } func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.PacketConn, guestAddr string) { - defer conn.Close() - id := fmt.Sprintf("udp-%s", conn.LocalAddr().String()) stream, err := client.Tunnel(ctx) @@ -52,79 +41,46 @@ func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgen return } - g, _ := errgroup.WithContext(ctx) - - g.Go(func() error { - buf := make([]byte, 65507) - for { - n, addr, err := conn.ReadFrom(buf) - // We must handle n > 0 bytes before considering the error. - // https://pkg.go.dev/net#PacketConn - if n > 0 { - msg := &api.TunnelMessage{ - Id: id + "-" + addr.String(), - Protocol: "udp", - GuestAddr: guestAddr, - Data: buf[:n], - UdpTargetAddr: addr.String(), - } - if err := stream.Send(msg); err != nil { - return err - } - } - if err != nil { - // https://pkg.go.dev/net#PacketConn does not mention io.EOF semantics. - if errors.Is(err, io.EOF) { - return nil - } - return err - } - } - }) + // Handshake message to start tunnel + if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil { + logrus.Errorf("could not start udp tunnel for id: %s error:%v", id, err) + return + } - g.Go(func() error { - for { - // Not documented: when err != nil, in is always nil. - in, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } - return err - } - addr, err := net.ResolveUDPAddr("udp", in.UdpTargetAddr) - if err != nil { - return err - } - _, err = conn.WriteTo(in.Data, addr) - if err != nil { - return err - } - } + proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) { + rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "udp"} + return rw, nil }) - - if err := g.Wait(); err != nil { - logrus.Debugf("error in udp tunnel for id: %s error:%v", id, err) + if err != nil { + logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", id, err) + return } + + defer func() { + err := proxy.Close() + if err != nil { + logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", id, err) + } + }() + proxy.Run() } type GrpcClientRW struct { - id string - addr string - stream api.GuestService_TunnelClient + id string + addr string + + protocol string + stream api.GuestService_TunnelClient } -var _ io.ReadWriter = (*GrpcClientRW)(nil) +var _ net.Conn = (*GrpcClientRW)(nil) -func (g GrpcClientRW) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } +func (g *GrpcClientRW) Write(p []byte) (n int, err error) { err = g.stream.Send(&api.TunnelMessage{ Id: g.id, GuestAddr: g.addr, Data: p, - Protocol: "tcp", + Protocol: g.protocol, }) if err != nil { return 0, err @@ -132,18 +88,35 @@ func (g GrpcClientRW) Write(p []byte) (n int, err error) { return len(p), nil } -func (g GrpcClientRW) Read(p []byte) (n int, err error) { - // Not documented: when err != nil, in is always nil. +func (g *GrpcClientRW) Read(p []byte) (n int, err error) { in, err := g.stream.Recv() if err != nil { - if errors.Is(err, io.EOF) { - return 0, nil - } return 0, err } - if len(in.Data) == 0 { - return 0, nil - } copy(p, in.Data) return len(in.Data), nil } + +func (g *GrpcClientRW) Close() error { + return g.stream.CloseSend() +} + +func (g *GrpcClientRW) LocalAddr() net.Addr { + return &net.UnixAddr{Name: "grpc", Net: "unixpacket"} +} + +func (g *GrpcClientRW) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: "grpc", Net: "unixpacket"} +} + +func (g *GrpcClientRW) SetDeadline(_ time.Time) error { + return nil +} + +func (g *GrpcClientRW) SetReadDeadline(_ time.Time) error { + return nil +} + +func (g *GrpcClientRW) SetWriteDeadline(_ time.Time) error { + return nil +} diff --git a/pkg/portfwdserver/server.go b/pkg/portfwdserver/server.go index 314575e55fe..985fbeb3340 100644 --- a/pkg/portfwdserver/server.go +++ b/pkg/portfwdserver/server.go @@ -4,66 +4,79 @@ import ( "errors" "io" "net" + "time" + "github.com/lima-vm/lima/pkg/bicopy" "github.com/lima-vm/lima/pkg/guestagent/api" ) -type TunnelServer struct { - Conns map[string]net.Conn -} +type TunnelServer struct{} func NewTunnelServer() *TunnelServer { - return &TunnelServer{ - Conns: make(map[string]net.Conn), - } + return &TunnelServer{} } func (s *TunnelServer) Start(stream api.GuestService_TunnelServer) error { - for { - in, err := stream.Recv() + // Receive the handshake message to start tunnel + in, err := stream.Recv() + if err != nil { if errors.Is(err, io.EOF) { return nil } - if err != nil { - return err - } - if len(in.Data) == 0 { - continue - } + return err + } - conn, ok := s.Conns[in.Id] - if !ok { - conn, err = net.Dial(in.Protocol, in.GuestAddr) - if err != nil { - return err - } - s.Conns[in.Id] = conn - - writer := &GRPCServerWriter{id: in.Id, udpAddr: in.UdpTargetAddr, stream: stream} - go func() { - _, _ = io.Copy(writer, conn) - delete(s.Conns, writer.id) - }() - } - _, err = conn.Write(in.Data) - if err != nil { - return err - } + // We simply forward data form GRPC stream to net.Conn for both tcp and udp. So simple proxy is sufficient + conn, err := net.Dial(in.Protocol, in.GuestAddr) + if err != nil { + return err } + rw := &GRPCServerRW{stream: stream, id: in.Id} + bicopy.Bicopy(rw, conn, nil) + return nil } -type GRPCServerWriter struct { - id string - udpAddr string - stream api.GuestService_TunnelServer +type GRPCServerRW struct { + id string + stream api.GuestService_TunnelServer } -var _ io.Writer = (*GRPCServerWriter)(nil) +var _ net.Conn = (*GRPCServerRW)(nil) -func (g GRPCServerWriter) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - err = g.stream.Send(&api.TunnelMessage{Id: g.id, Data: p, UdpTargetAddr: g.udpAddr}) +func (g *GRPCServerRW) Write(p []byte) (n int, err error) { + err = g.stream.Send(&api.TunnelMessage{Id: g.id, Data: p}) return len(p), err } + +func (g *GRPCServerRW) Read(p []byte) (n int, err error) { + in, err := g.stream.Recv() + if err != nil { + return 0, err + } + copy(p, in.Data) + return len(in.Data), nil +} + +func (g *GRPCServerRW) Close() error { + return nil +} + +func (g *GRPCServerRW) LocalAddr() net.Addr { + return &net.UnixAddr{Name: "grpc", Net: "unixpacket"} +} + +func (g *GRPCServerRW) RemoteAddr() net.Addr { + return &net.UnixAddr{Name: "grpc", Net: "unixpacket"} +} + +func (g *GRPCServerRW) SetDeadline(_ time.Time) error { + return nil +} + +func (g *GRPCServerRW) SetReadDeadline(_ time.Time) error { + return nil +} + +func (g *GRPCServerRW) SetWriteDeadline(_ time.Time) error { + return nil +}