diff --git a/x/httpproxy/connect_handler.go b/x/httpproxy/connect_handler.go index c88bc56d..92e636f8 100644 --- a/x/httpproxy/connect_handler.go +++ b/x/httpproxy/connect_handler.go @@ -95,23 +95,32 @@ func (h *connectHandler) ServeHTTP(proxyResp http.ResponseWriter, proxyReq *http return } - httpConn, _, err := hijacker.Hijack() + httpConn, clientRW, err := hijacker.Hijack() if err != nil { http.Error(proxyResp, "Failed to hijack connection", http.StatusInternalServerError) return } - defer httpConn.Close() + // TODO(fortuna): Use context.AfterFunc after we migrate to Go 1.21. + go func() { + // We close the hijacked connection when the context is done. This way + // we allow the HTTP server to control the request lifetime. + // The request context will be cancelled right after ServeHTTP returns, + // but it can be cancelled before, if the server uses a custom BaseContext. + <-proxyReq.Context().Done() + httpConn.Close() + }() // Inform the client that the connection has been established. - httpConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) + clientRW.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n")) + clientRW.Flush() // Relay data between client and target in both directions. go func() { - io.Copy(targetConn, httpConn) + io.Copy(targetConn, clientRW) targetConn.CloseWrite() }() - io.Copy(httpConn, targetConn) - // httpConn is closed by the defer httpConn.Close() above. + io.Copy(clientRW, targetConn) + clientRW.Flush() } // NewConnectHandler creates a [http.Handler] that handles CONNECT requests and forwards diff --git a/x/httpproxy/connect_handler_test.go b/x/httpproxy/connect_handler_test.go new file mode 100644 index 00000000..a2a019f8 --- /dev/null +++ b/x/httpproxy/connect_handler_test.go @@ -0,0 +1,40 @@ +// Copyright 2024 Jigsaw Operations LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpproxy + +import ( + "context" + "errors" + "net/http/httptest" + "testing" + + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/stretchr/testify/require" +) + +func TestNewConnectHandler(t *testing.T) { + h := NewConnectHandler(transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) { + return nil, errors.New("not implemented") + })) + + ch, ok := h.(*connectHandler) + require.True(t, ok) + require.NotNil(t, ch.dialer) + + req := httptest.NewRequest("CONNECT", "example.invalid:0", nil) + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + require.Equal(t, 503, resp.Result().StatusCode) +} diff --git a/x/mobileproxy/mobileproxy.go b/x/mobileproxy/mobileproxy.go index 745345a3..072da66f 100644 --- a/x/mobileproxy/mobileproxy.go +++ b/x/mobileproxy/mobileproxy.go @@ -20,6 +20,7 @@ package mobileproxy import ( "context" + "errors" "fmt" "io" "log" @@ -74,6 +75,11 @@ func (p *Proxy) Port() int { // The function associates the given 'dialer' with the specified 'path', allowing different dialers to be used for // different path-based proxies within the same application in the future. currently we only support one URL proxy. func (p *Proxy) AddURLProxy(path string, dialer *StreamDialer) { + if p.proxyHandler == nil { + // Called after Stop. Warn and ignore. + log.Println("Called Proxy.AddURLProxy after Stop") + return + } if len(path) == 0 || path[0] != '/' { path = "/" + path } @@ -92,6 +98,9 @@ func (p *Proxy) Stop(timeoutSeconds int) { log.Fatalf("Failed to shutdown gracefully: %v", err) p.server.Close() } + // Allow garbage collection in case the user keeps holding a reference to the Proxy. + p.proxyHandler = nil + p.server = nil } // RunProxy runs a local web proxy that listens on localAddress, and handles proxy requests by @@ -101,10 +110,25 @@ func RunProxy(localAddress string, dialer *StreamDialer) (*Proxy, error) { if err != nil { return nil, fmt.Errorf("could not listen on address %v: %v", localAddress, err) } + if dialer == nil { + return nil, errors.New("dialer must not be nil. Please create and pass a valid StreamDialer") + } + // The default http.Server doesn't close hijacked connections or cancel in-flight request contexts during + // shutdown. This can lead to lingering connections. We'll create a base context, propagated to requests, + // that is cancelled on shutdown. This enables handlers to gracefully terminate requests and close connections. + serverCtx, cancelCtx := context.WithCancelCause(context.Background()) proxyHandler := httpproxy.NewProxyHandler(dialer) proxyHandler.FallbackHandler = http.NotFoundHandler() - server := &http.Server{Handler: proxyHandler} + server := &http.Server{ + Handler: proxyHandler, + BaseContext: func(l net.Listener) context.Context { + return serverCtx + }, + } + server.RegisterOnShutdown(func() { + cancelCtx(errors.New("server stopped")) + }) go server.Serve(listener) host, portStr, err := net.SplitHostPort(listener.Addr().String())