Skip to content

Commit

Permalink
proxy: Fixed support for TLS verification of WebSocket connections
Browse files Browse the repository at this point in the history
  • Loading branch information
lhecker committed Dec 28, 2016
1 parent 153d4a5 commit b857265
Showing 1 changed file with 45 additions and 24 deletions.
69 changes: 45 additions & 24 deletions caddyhttp/proxy/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {
MaxIdleConnsPerHost: -1,
}
if b, _ := base.(*http.Transport); b != nil {
tlsClientConfig := b.TLSClientConfig
if tlsClientConfig.NextProtos != nil {
tlsClientConfig = cloneTLSClientConfig(tlsClientConfig)
tlsClientConfig.NextProtos = nil
}

t.Proxy = b.Proxy
t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig)
t.TLSClientConfig.NextProtos = nil
t.TLSClientConfig = tlsClientConfig
t.TLSHandshakeTimeout = b.TLSHandshakeTimeout
t.Dial = b.Dial
t.DialTLS = b.DialTLS
Expand All @@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport {

dial := getTransportDial(t)
dialTLS := getTransportDialTLS(t)

t.Dial = func(network, addr string) (net.Conn, error) {
c, err := dial(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}

if dialTLS != nil {
t.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := dialTLS(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}
t.DialTLS = func(network, addr string) (net.Conn, error) {
c, err := dialTLS(network, addr)
hj.Conn = c
return &hijackedConn{c, hj}, err
}

return hj
Expand All @@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e
return defaultDialer.Dial
}

// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil
// getTransportDial always returns a TLS Dialer
// and defaults to the existing t.DialTLS.
func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) {
if t.DialTLS != nil {
return t.DialTLS
}
if t.TLSClientConfig == nil {
return nil
}

// newConnHijackerTransport will modify t.Dial after calling this method
// => Create a backup reference.
plainDial := getTransportDial(t)

// The following DialTLS implementation stems from the Go stdlib and
// is identical to what happens if DialTLS is not provided.
// Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051
return func(network, addr string) (net.Conn, error) {
plainConn, err := plainDial(network, addr)
if err != nil {
return nil, err
}

tlsConn := tls.Client(plainConn, t.TLSClientConfig)
tlsClientConfig := t.TLSClientConfig
if tlsClientConfig == nil {
tlsClientConfig = &tls.Config{}
}
if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" {
tlsClientConfig.ServerName = stripPort(addr)
}

tlsConn := tls.Client(plainConn, tlsClientConfig)
errc := make(chan error, 2)
var timer *time.Timer
if d := t.TLSHandshakeTimeout; d != 0 {
Expand All @@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
plainConn.Close()
return nil, err
}
if !t.TLSClientConfig.InsecureSkipVerify {
serverName := t.TLSClientConfig.ServerName
if serverName == "" {
serverName = addr
idx := strings.LastIndex(serverName, ":")
if idx != -1 {
serverName = serverName[:idx]
}
if !tlsClientConfig.InsecureSkipVerify {
hostname := tlsClientConfig.ServerName
if hostname == "" {
hostname = stripPort(addr)
}
if err := tlsConn.VerifyHostname(serverName); err != nil {
if err := tlsConn.VerifyHostname(hostname); err != nil {
plainConn.Close()
return nil, err
}
Expand All @@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn
}
}

// stripPort returns address without its port if it has one and
// works with IP addresses as well as hostnames formatted as host:port.
//
// IPv6 addresses (excluding the port) must be enclosed in
// square brackets similar to the requirements of Go's stdlib.
func stripPort(address string) string {
// Keep in mind that the address might be a IPv6 address
// and thus contain a colon, but not have a port.
portIdx := strings.LastIndex(address, ":")
ipv6Idx := strings.LastIndex(address, "]")
if portIdx > ipv6Idx {
address = address[:portIdx]
}
return address
}

type tlsHandshakeTimeoutError struct{}

func (tlsHandshakeTimeoutError) Timeout() bool { return true }
Expand Down

0 comments on commit b857265

Please sign in to comment.