Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifically accessing the proxy source address #14

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 42 additions & 13 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ type Listener struct {
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address.
type Conn struct {
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
bufReader *bufio.Reader
conn net.Conn
dstAddr *net.TCPAddr
srcAddr *net.TCPAddr
// Any error encountered while reading the proxyproto header
proxyErr error
useConnAddr bool
once sync.Once
proxyHeaderTimeout time.Duration
Expand Down Expand Up @@ -158,7 +160,7 @@ func (p *Conn) LocalAddr() net.Addr {
// protocol is being used, otherwise just returns the address of
// the socket peer. If there is an error parsing the header, the
// address of the client is not returned, and the socket is closed.
// Once implication of this is that the call could block if the
// One implication of this is that the call could block if the
// client is slow. Using a Deadline is recommended if this is called
// before Read()
func (p *Conn) RemoteAddr() net.Addr {
Expand All @@ -169,6 +171,22 @@ func (p *Conn) RemoteAddr() net.Addr {
return p.conn.RemoteAddr()
}

// ProxySourceAddr returns the source address according to the proxyproto.
// If there was an error parsing the proxy header, that error will be returned.
// This call will read the proxy header if it hasn't been read yet, and thus
// using a Deadline is recommended if this is called before Read().
// This method, if called, can be used to reliably check if the connection is
// using a proxy.
// If UnknownTrue is set on the listener, ProxySourcAddr may return 'nil, nil'
// in the case of a proxy protocol being used with PROXY UNKNOWN.
func (p *Conn) ProxySourceAddr() (net.Addr, error) {
p.checkPrefixOnce()
if p.srcAddr == nil {
return nil, p.proxyErr
}
return p.srcAddr, p.proxyErr
}

func (p *Conn) SetDeadline(t time.Time) error {
return p.conn.SetDeadline(t)
}
Expand Down Expand Up @@ -203,6 +221,7 @@ func (p *Conn) checkPrefix() error {
inp, err := p.bufReader.Peek(i)

if err != nil {
p.proxyErr = fmt.Errorf("error while trying to read proxy header: %w", err)
if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
return nil
} else {
Expand All @@ -212,13 +231,15 @@ func (p *Conn) checkPrefix() error {

// Check for a prefix mis-match, quit early
if !bytes.Equal(inp, prefix[:i]) {
p.proxyErr = fmt.Errorf("connection read did not match proxy header")
return nil
}
}

// Read the header line
header, err := p.bufReader.ReadString('\n')
if err != nil {
p.proxyErr = fmt.Errorf("error reading first proxyheader line: %w", err)
p.conn.Close()
return err
}
Expand All @@ -230,53 +251,61 @@ func (p *Conn) checkPrefix() error {
parts := strings.Split(header, " ")
if len(parts) < 2 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
p.proxyErr = fmt.Errorf("invalid header line: %s", header)
return p.proxyErr
}

// Verify the type is known
switch parts[1] {
case "UNKNOWN":
if !p.unknownOK || len(parts) != 2 {
p.conn.Close()
return fmt.Errorf("Invalid UNKNOWN header line: %s", header)
p.proxyErr = fmt.Errorf("invalid UNKNOWN header line: %s", header)
return p.proxyErr
}
p.useConnAddr = true
return nil
case "TCP4":
case "TCP6":
default:
p.conn.Close()
return fmt.Errorf("Unhandled address type: %s", parts[1])
p.proxyErr = fmt.Errorf("Unhandled address type: %s", parts[1])
return p.proxyErr
}

if len(parts) != 6 {
p.conn.Close()
return fmt.Errorf("Invalid header line: %s", header)
p.proxyErr = fmt.Errorf("Invalid header line (should have 6 parts): %s", header)
return p.proxyErr
}

// Parse out the source address
ip := net.ParseIP(parts[2])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid source ip: %s", parts[2])
p.proxyErr = fmt.Errorf("Invalid source ip: %s", parts[2])
return p.proxyErr
}
port, err := strconv.Atoi(parts[4])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid source port: %s", parts[4])
p.proxyErr = fmt.Errorf("Invalid source port: %s", parts[4])
return p.proxyErr
}
p.srcAddr = &net.TCPAddr{IP: ip, Port: port}

// Parse out the destination address
ip = net.ParseIP(parts[3])
if ip == nil {
p.conn.Close()
return fmt.Errorf("Invalid destination ip: %s", parts[3])
p.proxyErr = fmt.Errorf("Invalid destination ip: %s", parts[3])
return p.proxyErr
}
port, err = strconv.Atoi(parts[5])
if err != nil {
p.conn.Close()
return fmt.Errorf("Invalid destination port: %s", parts[5])
p.proxyErr = fmt.Errorf("Invalid destination port: %s", parts[5])
return p.proxyErr
}
p.dstAddr = &net.TCPAddr{IP: ip, Port: port}

Expand Down
31 changes: 30 additions & 1 deletion protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ func TestPassthrough(t *testing.T) {
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}

if src, err := conn.(*Conn).ProxySourceAddr(); err == nil {
t.Fatalf("expected error on passthrough, but got nil and src %v", src)
}
}

func TestTimeout(t *testing.T) {
Expand Down Expand Up @@ -185,6 +189,13 @@ func TestParse_ipv4(t *testing.T) {
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
src, err := conn.(*Conn).ProxySourceAddr()
if err != nil {
t.Fatalf("expected no error on proxy source addr: %v", err)
}
if src != addr {
t.Fatalf("expected addrs to match in working proxy case: %v != %v", src, addr)
}
}

func TestParse_ipv6(t *testing.T) {
Expand Down Expand Up @@ -244,6 +255,13 @@ func TestParse_ipv6(t *testing.T) {
if addr.Port != 1000 {
t.Fatalf("bad: %v", addr)
}
src, err := conn.(*Conn).ProxySourceAddr()
if err != nil {
t.Fatalf("expected no error on proxy source addr: %v", err)
}
if src != addr {
t.Fatalf("expected addrs to match in working proxy case: %v != %v", src, addr)
}
}

func TestParse_Unknown(t *testing.T) {
Expand Down Expand Up @@ -294,7 +312,13 @@ func TestParse_Unknown(t *testing.T) {
if _, err := conn.Write([]byte("pong")); err != nil {
t.Fatalf("err: %v", err)
}

src, err := conn.(*Conn).ProxySourceAddr()
if err != nil {
t.Fatalf("expected no error on proxy source addr for UNKNOWN: %v", err)
}
if src != nil {
t.Fatalf("expected src addr to be nil on UNKNOWN proxy: %v", src)
}
}

func TestParse_BadHeader(t *testing.T) {
Expand Down Expand Up @@ -337,6 +361,11 @@ func TestParse_BadHeader(t *testing.T) {
t.Fatalf("bad: %v", addr)
}

// ProxySourceAddr should return the error
if _, err := conn.(*Conn).ProxySourceAddr(); err == nil {
t.Fatalf("expected an error when the proxy header was wrong")
}

// Read should fail
recv := make([]byte, 4)
_, err = conn.Read(recv)
Expand Down