Skip to content

Commit

Permalink
fix: send proxy protocol and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
haveachin committed Feb 5, 2024
1 parent 2fecd94 commit 9fcca8d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 10 deletions.
11 changes: 7 additions & 4 deletions cmd/infrared/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,19 @@ func run() error {
case sig := <-sigChan:
log.Info().Msg("Received " + sig.String())
case err := <-errChan:
if errors.Is(err, ir.ErrNoServers) {
switch {
case errors.Is(err, ir.ErrNoServers):
log.Fatal().
Str("docs", "https://infrared.dev/config/proxies").
Msg("No proxy configs found; Check the docs")
} else if errors.Is(err, ir.ErrNoTrustedCIDRs) {
case errors.Is(err, ir.ErrNoTrustedCIDRs):
log.Fatal().
Str("docs", "https://infrared.dev/features/proxy-protocol#receive-proxy-protocol").
Msg("Receive PROXY Protocol enabled, but no CIDRs specified; Check the docs")
} else if err != nil {
return err
default:
if err != nil {
return err
}
}
}

Expand Down
1 change: 1 addition & 0 deletions pkg/infrared/infrared.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ func (ir *Infrared) handleConn(c *clientConn) error {
c.reqDomain = ServerDomain(reqDomain)

resp, err := ir.sr.RequestServer(ServerRequest{
ClientAddr: c.RemoteAddr(),
Domain: c.reqDomain,
IsLogin: c.handshake.IsLoginRequest(),
ProtocolVersion: protocol.Version(c.handshake.ProtocolVersion),
Expand Down
18 changes: 17 additions & 1 deletion pkg/infrared/infrared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ func TestInfrared_SendProxyProtocol_False(t *testing.T) {
func TestInfrared_ReceiveProxyProtocol_True(t *testing.T) {
cfg := ir.NewConfig().
WithProxyProtocolReceive(true).
WithProxyProtocolTrustedCIDRs()
WithProxyProtocolTrustedCIDRs("127.0.0.1/32")

vi, _ := NewVirtualInfrared(cfg, false)
vc := vi.NewConn()
Expand All @@ -206,3 +206,19 @@ func TestInfrared_ReceiveProxyProtocol_True(t *testing.T) {
t.Fatal(err)
}
}

func TestInfrared_ReceiveProxyProtocol_False(t *testing.T) {
cfg := ir.NewConfig().
WithProxyProtocolReceive(false).
WithProxyProtocolTrustedCIDRs("127.0.0.1/32")

vi, _ := NewVirtualInfrared(cfg, false)
vc := vi.NewConn()
if err := vc.SendProxyProtocolHeader(); err != nil {
t.Fatal(err)
}
if err := vc.SendHandshake(handshaking.ServerBoundHandshake{}); err != nil {
return
}
t.Fatal("no disconnect after invalid proxy protocol header")
}
20 changes: 15 additions & 5 deletions pkg/infrared/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func (s Server) Dial() (*ServerConn, error) {
}

type ServerRequest struct {
ClientAddr net.Addr
Domain ServerDomain
IsLogin bool
ProtocolVersion protocol.Version
Expand Down Expand Up @@ -192,7 +193,7 @@ func (r DialServerResponder) respondeToStatusRequest(req ServerRequest, srv *Ser
r.respProvs[srv] = respProv
}

_, pk, err := respProv.StatusResponse(req.ProtocolVersion, req.ReadPackets)
_, pk, err := respProv.StatusResponse(req.ClientAddr, req.ProtocolVersion, req.ReadPackets)
if err != nil {
return ServerResponse{}, err
}
Expand All @@ -203,7 +204,7 @@ func (r DialServerResponder) respondeToStatusRequest(req ServerRequest, srv *Ser
}

type StatusResponseProvider interface {
StatusResponse(protocol.Version, [2]protocol.Packet) (status.ResponseJSON, protocol.Packet, error)
StatusResponse(net.Addr, protocol.Version, [2]protocol.Packet) (status.ResponseJSON, protocol.Packet, error)
}

type statusCacheEntry struct {
Expand All @@ -226,13 +227,20 @@ type statusResponseProvider struct {
}

func (s *statusResponseProvider) requestNewStatusResponseJSON(
cliAddr net.Addr,
readPks [2]protocol.Packet,
) (status.ResponseJSON, protocol.Packet, error) {
rc, err := s.server.Dial()
if err != nil {
return status.ResponseJSON{}, protocol.Packet{}, err
}

if s.server.cfg.SendProxyProtocol {
if err := writeProxyProtocolHeader(cliAddr, rc); err != nil {
return status.ResponseJSON{}, protocol.Packet{}, err
}
}

if err := rc.WritePackets(readPks[0], readPks[1]); err != nil {
return status.ResponseJSON{}, protocol.Packet{}, err
}
Expand All @@ -257,11 +265,12 @@ func (s *statusResponseProvider) requestNewStatusResponseJSON(
}

func (s *statusResponseProvider) StatusResponse(
cliAddr net.Addr,
protVer protocol.Version,
readPks [2]protocol.Packet,
) (status.ResponseJSON, protocol.Packet, error) {
if s.cacheTTL <= 0 {
return s.requestNewStatusResponseJSON(readPks)
return s.requestNewStatusResponseJSON(cliAddr, readPks)
}

// Prunes all expired status reponses
Expand All @@ -273,17 +282,18 @@ func (s *statusResponseProvider) StatusResponse(
hash, okHash := s.statusHash[protVer]
entry, okCache := s.statusResponseCache[hash]
if !okHash || !okCache {
return s.cacheResponse(protVer, readPks)
return s.cacheResponse(cliAddr, protVer, readPks)
}

return entry.responseJSON, entry.responsePk, nil
}

func (s *statusResponseProvider) cacheResponse(
cliAddr net.Addr,
protVer protocol.Version,
readPks [2]protocol.Packet,
) (status.ResponseJSON, protocol.Packet, error) {
newStatusResp, pk, err := s.requestNewStatusResponseJSON(readPks)
newStatusResp, pk, err := s.requestNewStatusResponseJSON(cliAddr, readPks)
if err != nil {
return status.ResponseJSON{}, protocol.Packet{}, err
}
Expand Down

0 comments on commit 9fcca8d

Please sign in to comment.