Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support TLS connection to HTTP Proxy #950

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 35 additions & 26 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"net/url"
"strings"
"time"

"golang.org/x/net/proxy"
)

// ErrBadHandshake is returned when the server response to opening handshake is
Expand Down Expand Up @@ -244,18 +246,39 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
defer cancel()
}

var netDial netDialerFunc
switch {
case u.Scheme == "https" && d.NetDialTLSContext != nil:
netDial = d.NetDialTLSContext
case d.NetDialContext != nil:
netDial = d.NetDialContext
case d.NetDial != nil:
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return d.NetDial(net, addr)
netDial := newNetDialerFunc(u.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)

// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
forwardDial := newNetDialerFunc(proxyURL.Scheme, d.NetDial, d.NetDialContext, d.NetDialTLSContext)
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
tlsClientConfig := cloneTLSConfig(d.TLSClientConfig)
if tlsClientConfig.ServerName == "" {
_, hostNoPort := hostPortNoPort(proxyURL)
tlsClientConfig.ServerName = hostNoPort
}
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, tlsClientConfig)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
netDial = newHTTPProxyDialerFunc(proxyURL, forwardDial, nil)
} else {
dialer, err := proxy.FromURL(proxyURL, forwardDial)
if err != nil {
return nil, nil, err
}
if d, ok := dialer.(proxy.ContextDialer); ok {
netDial = d.DialContext
} else {
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
return dialer.Dial(net, addr)
}
}
}
}
default:
netDial = (&net.Dialer{}).DialContext
}

// If needed, wrap the dial function to set the connection deadline.
Expand All @@ -275,20 +298,6 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}

// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
netDial, err = proxyFromURL(proxyURL, netDial)
if err != nil {
return nil, nil, err
}
}
}
Comment on lines -278 to -290
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to set a deadline for the dialing to proxy, so moved the line up.


hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.GetConn != nil {
Expand Down Expand Up @@ -359,7 +368,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
"sharing tlsServerName.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
Expand Down
146 changes: 145 additions & 1 deletion client_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
return &s
}

type cstProxyServer struct{}

func (s *cstProxyServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Method != http.MethodConnect {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}

conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
defer conn.Close()

upstream, err := (&net.Dialer{}).DialContext(req.Context(), "tcp", req.URL.Host)
if err != nil {
_, _ = fmt.Fprintf(conn, "HTTP/1.1 502 Bad Gateway\r\n\r\n")
return
}
defer upstream.Close()

_, _ = fmt.Fprintf(conn, "HTTP/1.1 200 Connection established\r\n\r\n")

wg := sync.WaitGroup{}
wg.Add(2)
Copy link

@adrianosela adrianosela Jul 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion: I think you probably want to exit as soon as one io.Copy hits EOF...

Here's two ways I think you can achieve this:

	var wg sync.WaitGroup
	defer wg.Wait()

	wg.Add(1)
	go func() {
		defer wg.Done()

		// abort blocked reads from upstream when done reading from conn
		defer upstream.SetDeadline(time.Now().Add(-1 * time.Hour))

		_, _ = io.Copy(upstream, conn)
	}()

	wg.Add(1)
	go func() {
		defer wg.Done()

		// abort blocked reads from conn when done reading from upstream
		defer conn.SetDeadline(time.Now().Add(-1 * time.Hour))

		_, _ = io.Copy(conn, upstream)
	}()

OR more simple - block on an unbuffered channel:

        done := make(chan struct{})

	go func() {
		_, _ = io.Copy(upstream, conn)
                done <- struct{}{}
	}()

	go func() {
		_, _ = io.Copy(conn, upstream)
               done <- struct{}{}
	}()

        <- done

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I prefer the below. I used wg because It's just a test helper function.

I'll change to the below one.

go func() {
defer wg.Done()
_, _ = io.Copy(upstream, conn)
}()
go func() {
defer wg.Done()
_, _ = io.Copy(conn, upstream)
}()
wg.Wait()
}

func newProxyServer() *httptest.Server {
return httptest.NewServer(&cstProxyServer{})
}

func newTLSProxyServer() *httptest.Server {
return httptest.NewTLSServer(&cstProxyServer{})
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously, only had a temporary proxy implementation, and couldn't test the TLS proxy, so implemented it.

func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Because tests wait for a response from a server, we are guaranteed that
// the wait group count is incremented before the test waits on the group
Expand Down Expand Up @@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
}

func TestProxyDial(t *testing.T) {

s := newServer(t)
defer s.Close()

Expand Down Expand Up @@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
sendRecv(t, ws)
}

func TestProxyDialer(t *testing.T) {
testcases := []struct {
name string
isTLS bool
tlsServerName string
insecureSkipVerify bool
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
}{{
name: "http",
isTLS: false,
}, {
name: "https",
isTLS: true,
}, {
name: "https with ServerName",
isTLS: true,
tlsServerName: "example.com",
}, {
name: "https with insecureSkipVerify",
isTLS: true,
insecureSkipVerify: true,
}, {
name: "https with netDialTLSContext",
isTLS: true,
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer := &tls.Dialer{
Config: &tls.Config{
InsecureSkipVerify: true,
},
}
return dialer.DialContext(ctx, network, addr)
},
}}

for _, tc := range testcases {
t.Run(tc.name, func(tt *testing.T) {
s := newServer(tt)
defer s.Close()

var ps *httptest.Server
if tc.isTLS {
ps = newTLSProxyServer()
} else {
ps = newProxyServer()
}

psurl, _ := url.Parse(ps.URL)

netDialCalled := false

cstDialer := cstDialer // make local copy for modification on next line.
cstDialer.Proxy = http.ProxyURL(psurl)
if tc.isTLS {
cstDialer.TLSClientConfig = &tls.Config{
RootCAs: rootCAs(tt, ps),
ServerName: tc.tlsServerName,
InsecureSkipVerify: tc.insecureSkipVerify,
}
if tc.netDialTLSContext != nil {
cstDialer.NetDialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
netDialCalled = true
return tc.netDialTLSContext(ctx, network, addr)
}
} else {
netDialCalled = true
}
} else {
netDialCalled = true
}

connect := false
origHandler := ps.Config.Handler

// Capture the request Host header.
ps.Config.Handler = http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodConnect {
connect = true
}

origHandler.ServeHTTP(w, r)
})

ws, _, err := cstDialer.Dial(s.URL, nil)
if err != nil {
tt.Fatalf("Dial: %v", err)
}
defer ws.Close()
sendRecv(tt, ws)

if !connect {
tt.Error("connect not received")
}
if !netDialCalled {
tt.Error("netDialTLSContext not called")
}
})
}
}

func TestProxyAuthorizationDial(t *testing.T) {
s := newServer(t)
defer s.Close()
Expand Down
Loading