Skip to content

Commit

Permalink
proxyproto: PROXY protocol net.Conn and net.Listener impl
Browse files Browse the repository at this point in the history
Use github.com/pires/go-proxyproto header parsing and provide better (safer) net.Conn and net.Listener implementations.
The Conn type is designed after tls.Conn.

We never read more than header bytes from connection.
This allows to eliminate any 3rd party readers,
and to use the connection ReadFrom() and WriteTo() functions if provided.
The underlying connection is available via Conn.NetConn().

On context cancellation the connection is closed to terminate the paring go routine,
and to avoid unspecified behaviour caused by double reads from header parsing and user code.
  • Loading branch information
mmatczuk committed Sep 20, 2024
1 parent 6afe460 commit 54d89e8
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/kevinburke/hostsfile v0.0.0-20220522040509-e5e984885321
github.com/mitchellh/go-wordwrap v1.0.1
github.com/mmatczuk/anyflag v0.0.0-20240709090339-eb9e24cd1b44
github.com/pires/go-proxyproto v0.7.0
github.com/prometheus/client_golang v1.20.4
github.com/prometheus/client_model v0.6.1
github.com/prometheus/common v0.59.1
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pires/go-proxyproto v0.7.0 h1:IukmRewDQFWC7kfnb66CSomk2q/seBuilHBYFwyq0Hs=
github.com/pires/go-proxyproto v0.7.0/go.mod h1:Vz/1JPY/OACxWGQNIRY2BeyDmpoaWmEP40O9LbuiFR4=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down
177 changes: 177 additions & 0 deletions proxyproto/proxyproto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright 2022-2024 Sauce Labs Inc., all rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package proxyproto

import (
"bufio"
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"

"github.com/pires/go-proxyproto"
)

// Conn wraps a net.Conn and provides access to the proxy protocol header.
// If the header is not present or cannot be read within the timeout,
// the connection is closed.
type Conn struct {
net.Conn

readHeaderTimeout time.Duration
isHeaderRead atomic.Bool
headerMu sync.Mutex
header proxyproto.Header
headerErr error
}

func (c *Conn) NetConn() net.Conn {
return c.Conn
}

func (c *Conn) LocalAddr() net.Addr {
if err := c.readHeader(); err != nil {
return c.Conn.LocalAddr()
}

if c.headerErr != nil || c.header.Command.IsLocal() {
return c.Conn.LocalAddr()
}

return c.header.DestinationAddr
}

func (c *Conn) RemoteAddr() net.Addr {
if err := c.readHeader(); err != nil {
return c.Conn.RemoteAddr()
}

if c.headerErr != nil || c.header.Command.IsLocal() {
return c.Conn.RemoteAddr()
}

return c.header.SourceAddr
}

func (c *Conn) Read(b []byte) (n int, err error) {
if err := c.readHeader(); err != nil {
return 0, err
}
return c.Conn.Read(b)
}

func (c *Conn) Write(b []byte) (n int, err error) {
if err := c.readHeader(); err != nil {
return 0, err
}
return c.Conn.Write(b)
}

func (c *Conn) Header() (proxyproto.Header, error) {
return c.HeaderContext(context.Background())
}

func (c *Conn) HeaderContext(ctx context.Context) (proxyproto.Header, error) {
if err := c.readHeaderContext(ctx); err != nil {
return proxyproto.Header{}, err
}
return c.header, nil
}

func (c *Conn) readHeader() error {
return c.readHeaderContext(context.Background())
}

func (c *Conn) readHeaderContext(ctx context.Context) error {
if c.isHeaderRead.Load() {
return c.headerErr
}

c.headerMu.Lock()
defer c.headerMu.Unlock()

if c.isHeaderRead.Load() {
return c.headerErr
}

t0 := time.Now()
if c.readHeaderTimeout > 0 {
if d, ok := ctx.Deadline(); !ok || d.Sub(t0) > c.readHeaderTimeout {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, c.readHeaderTimeout)
defer cancel()
}
}

type result struct {
header *proxyproto.Header
err error
}
resCh := make(chan result)

go func() {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256
// Use a byteReader to read only one byte at a time,
// so we can read the header without consuming more bytes than needed.
// On success, the reader must be empty.
// Otherwise, the connection is closed on timeout or never read on error.
br := bufio.NewReaderSize(byteReader{c.Conn}, bufSize)

var r result
r.header, r.err = proxyproto.Read(br)

if r.err == nil && br.Buffered() > 0 {
panic("proxy protocol header read: unexpected data after header")
}

resCh <- r
}()

select {
case <-ctx.Done():
c.Conn.Close()
c.headerErr = fmt.Errorf("proxy protocol header read timeout: %w", ctx.Err())
case r := <-resCh:
c.header = *r.header
c.headerErr = r.err
}

c.isHeaderRead.Store(true)

return c.headerErr
}

type Listener struct {
net.Listener
ReadHeaderTimeout time.Duration
}

func (l *Listener) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}

return &Conn{
Conn: c,
readHeaderTimeout: l.ReadHeaderTimeout,
}, nil
}

type byteReader struct {
r io.Reader
}

func (r byteReader) Read(p []byte) (int, error) {
return r.r.Read(p[:1])
}

0 comments on commit 54d89e8

Please sign in to comment.