Skip to content

Commit

Permalink
Merge pull request #2985 from balajiv113/grpc-pw
Browse files Browse the repository at this point in the history
Revamp GRPC Port forwarding tunnels to use existing proxy
  • Loading branch information
AkihiroSuda authored Dec 24, 2024
2 parents 5148ffd + 17b2d58 commit 5fb9353
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 131 deletions.
151 changes: 62 additions & 89 deletions pkg/portfwd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ package portfwd

import (
"context"
"errors"
"fmt"
"io"
"net"
"time"

"github.com/containers/gvisor-tap-vsock/pkg/services/forwarder"
"github.com/lima-vm/lima/pkg/bicopy"
"github.com/lima-vm/lima/pkg/guestagent/api"
guestagentclient "github.com/lima-vm/lima/pkg/guestagent/api/client"
"github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)

func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.Conn, guestAddr string) {
defer conn.Close()

id := fmt.Sprintf("tcp-%s-%s", conn.LocalAddr().String(), conn.RemoteAddr().String())

stream, err := client.Tunnel(ctx)
Expand All @@ -24,26 +22,17 @@ func HandleTCPConnection(ctx context.Context, client *guestagentclient.GuestAgen
return
}

g, _ := errgroup.WithContext(ctx)

rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr}
g.Go(func() error {
_, err := io.Copy(rw, conn)
return err
})
g.Go(func() error {
_, err := io.Copy(conn, rw)
return err
})

if err := g.Wait(); err != nil {
logrus.Debugf("error in tcp tunnel for id: %s error:%v", id, err)
// Handshake message to start tunnel
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "tcp", GuestAddr: guestAddr}); err != nil {
logrus.Errorf("could not start tcp tunnel for id: %s error:%v", id, err)
return
}

rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "tcp"}
bicopy.Bicopy(rw, conn, nil)
}

func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgentClient, conn net.PacketConn, guestAddr string) {
defer conn.Close()

id := fmt.Sprintf("udp-%s", conn.LocalAddr().String())

stream, err := client.Tunnel(ctx)
Expand All @@ -52,98 +41,82 @@ func HandleUDPConnection(ctx context.Context, client *guestagentclient.GuestAgen
return
}

g, _ := errgroup.WithContext(ctx)

g.Go(func() error {
buf := make([]byte, 65507)
for {
n, addr, err := conn.ReadFrom(buf)
// We must handle n > 0 bytes before considering the error.
// https://pkg.go.dev/net#PacketConn
if n > 0 {
msg := &api.TunnelMessage{
Id: id + "-" + addr.String(),
Protocol: "udp",
GuestAddr: guestAddr,
Data: buf[:n],
UdpTargetAddr: addr.String(),
}
if err := stream.Send(msg); err != nil {
return err
}
}
if err != nil {
// https://pkg.go.dev/net#PacketConn does not mention io.EOF semantics.
if errors.Is(err, io.EOF) {
return nil
}
return err
}
}
})
// Handshake message to start tunnel
if err := stream.Send(&api.TunnelMessage{Id: id, Protocol: "udp", GuestAddr: guestAddr}); err != nil {
logrus.Errorf("could not start udp tunnel for id: %s error:%v", id, err)
return
}

g.Go(func() error {
for {
// Not documented: when err != nil, in is always nil.
in, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return err
}
addr, err := net.ResolveUDPAddr("udp", in.UdpTargetAddr)
if err != nil {
return err
}
_, err = conn.WriteTo(in.Data, addr)
if err != nil {
return err
}
}
proxy, err := forwarder.NewUDPProxy(conn, func() (net.Conn, error) {
rw := &GrpcClientRW{stream: stream, id: id, addr: guestAddr, protocol: "udp"}
return rw, nil
})

if err := g.Wait(); err != nil {
logrus.Debugf("error in udp tunnel for id: %s error:%v", id, err)
if err != nil {
logrus.Errorf("error in udp tunnel proxy for id: %s error:%v", id, err)
return
}

defer func() {
err := proxy.Close()
if err != nil {
logrus.Errorf("error in closing udp tunnel proxy for id: %s error:%v", id, err)
}
}()
proxy.Run()
}

type GrpcClientRW struct {
id string
addr string
stream api.GuestService_TunnelClient
id string
addr string

protocol string
stream api.GuestService_TunnelClient
}

var _ io.ReadWriter = (*GrpcClientRW)(nil)
var _ net.Conn = (*GrpcClientRW)(nil)

func (g GrpcClientRW) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
func (g *GrpcClientRW) Write(p []byte) (n int, err error) {
err = g.stream.Send(&api.TunnelMessage{
Id: g.id,
GuestAddr: g.addr,
Data: p,
Protocol: "tcp",
Protocol: g.protocol,
})
if err != nil {
return 0, err
}
return len(p), nil
}

func (g GrpcClientRW) Read(p []byte) (n int, err error) {
// Not documented: when err != nil, in is always nil.
func (g *GrpcClientRW) Read(p []byte) (n int, err error) {
in, err := g.stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return 0, nil
}
return 0, err
}
if len(in.Data) == 0 {
return 0, nil
}
copy(p, in.Data)
return len(in.Data), nil
}

func (g *GrpcClientRW) Close() error {
return g.stream.CloseSend()
}

func (g *GrpcClientRW) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "grpc", Net: "unixpacket"}
}

func (g *GrpcClientRW) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "grpc", Net: "unixpacket"}
}

func (g *GrpcClientRW) SetDeadline(_ time.Time) error {
return nil
}

func (g *GrpcClientRW) SetReadDeadline(_ time.Time) error {
return nil
}

func (g *GrpcClientRW) SetWriteDeadline(_ time.Time) error {
return nil
}
97 changes: 55 additions & 42 deletions pkg/portfwdserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,66 +4,79 @@ import (
"errors"
"io"
"net"
"time"

"github.com/lima-vm/lima/pkg/bicopy"
"github.com/lima-vm/lima/pkg/guestagent/api"
)

type TunnelServer struct {
Conns map[string]net.Conn
}
type TunnelServer struct{}

func NewTunnelServer() *TunnelServer {
return &TunnelServer{
Conns: make(map[string]net.Conn),
}
return &TunnelServer{}
}

func (s *TunnelServer) Start(stream api.GuestService_TunnelServer) error {
for {
in, err := stream.Recv()
// Receive the handshake message to start tunnel
in, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return err
}
if len(in.Data) == 0 {
continue
}
return err
}

conn, ok := s.Conns[in.Id]
if !ok {
conn, err = net.Dial(in.Protocol, in.GuestAddr)
if err != nil {
return err
}
s.Conns[in.Id] = conn

writer := &GRPCServerWriter{id: in.Id, udpAddr: in.UdpTargetAddr, stream: stream}
go func() {
_, _ = io.Copy(writer, conn)
delete(s.Conns, writer.id)
}()
}
_, err = conn.Write(in.Data)
if err != nil {
return err
}
// We simply forward data form GRPC stream to net.Conn for both tcp and udp. So simple proxy is sufficient
conn, err := net.Dial(in.Protocol, in.GuestAddr)
if err != nil {
return err
}
rw := &GRPCServerRW{stream: stream, id: in.Id}
bicopy.Bicopy(rw, conn, nil)
return nil
}

type GRPCServerWriter struct {
id string
udpAddr string
stream api.GuestService_TunnelServer
type GRPCServerRW struct {
id string
stream api.GuestService_TunnelServer
}

var _ io.Writer = (*GRPCServerWriter)(nil)
var _ net.Conn = (*GRPCServerRW)(nil)

func (g GRPCServerWriter) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
err = g.stream.Send(&api.TunnelMessage{Id: g.id, Data: p, UdpTargetAddr: g.udpAddr})
func (g *GRPCServerRW) Write(p []byte) (n int, err error) {
err = g.stream.Send(&api.TunnelMessage{Id: g.id, Data: p})
return len(p), err
}

func (g *GRPCServerRW) Read(p []byte) (n int, err error) {
in, err := g.stream.Recv()
if err != nil {
return 0, err
}
copy(p, in.Data)
return len(in.Data), nil
}

func (g *GRPCServerRW) Close() error {
return nil
}

func (g *GRPCServerRW) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "grpc", Net: "unixpacket"}
}

func (g *GRPCServerRW) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "grpc", Net: "unixpacket"}
}

func (g *GRPCServerRW) SetDeadline(_ time.Time) error {
return nil
}

func (g *GRPCServerRW) SetReadDeadline(_ time.Time) error {
return nil
}

func (g *GRPCServerRW) SetWriteDeadline(_ time.Time) error {
return nil
}

0 comments on commit 5fb9353

Please sign in to comment.