Skip to content

Commit

Permalink
feat: stop all active connections on Proxy.Stop() (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Apr 2, 2024
1 parent 48d8a48 commit 49ebf7c
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 7 deletions.
21 changes: 15 additions & 6 deletions x/httpproxy/connect_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,32 @@ func (h *connectHandler) ServeHTTP(proxyResp http.ResponseWriter, proxyReq *http
return
}

httpConn, _, err := hijacker.Hijack()
httpConn, clientRW, err := hijacker.Hijack()
if err != nil {
http.Error(proxyResp, "Failed to hijack connection", http.StatusInternalServerError)
return
}
defer httpConn.Close()
// TODO(fortuna): Use context.AfterFunc after we migrate to Go 1.21.
go func() {
// We close the hijacked connection when the context is done. This way
// we allow the HTTP server to control the request lifetime.
// The request context will be cancelled right after ServeHTTP returns,
// but it can be cancelled before, if the server uses a custom BaseContext.
<-proxyReq.Context().Done()
httpConn.Close()
}()

// Inform the client that the connection has been established.
httpConn.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
clientRW.Write([]byte("HTTP/1.1 200 Connection established\r\n\r\n"))
clientRW.Flush()

// Relay data between client and target in both directions.
go func() {
io.Copy(targetConn, httpConn)
io.Copy(targetConn, clientRW)
targetConn.CloseWrite()
}()
io.Copy(httpConn, targetConn)
// httpConn is closed by the defer httpConn.Close() above.
io.Copy(clientRW, targetConn)
clientRW.Flush()
}

// NewConnectHandler creates a [http.Handler] that handles CONNECT requests and forwards
Expand Down
40 changes: 40 additions & 0 deletions x/httpproxy/connect_handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright 2024 Jigsaw Operations LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package httpproxy

import (
"context"
"errors"
"net/http/httptest"
"testing"

"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/stretchr/testify/require"
)

func TestNewConnectHandler(t *testing.T) {
h := NewConnectHandler(transport.FuncStreamDialer(func(ctx context.Context, addr string) (transport.StreamConn, error) {
return nil, errors.New("not implemented")
}))

ch, ok := h.(*connectHandler)
require.True(t, ok)
require.NotNil(t, ch.dialer)

req := httptest.NewRequest("CONNECT", "example.invalid:0", nil)
resp := httptest.NewRecorder()
h.ServeHTTP(resp, req)
require.Equal(t, 503, resp.Result().StatusCode)
}
26 changes: 25 additions & 1 deletion x/mobileproxy/mobileproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package mobileproxy

import (
"context"
"errors"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -74,6 +75,11 @@ func (p *Proxy) Port() int {
// The function associates the given 'dialer' with the specified 'path', allowing different dialers to be used for
// different path-based proxies within the same application in the future. currently we only support one URL proxy.
func (p *Proxy) AddURLProxy(path string, dialer *StreamDialer) {
if p.proxyHandler == nil {
// Called after Stop. Warn and ignore.
log.Println("Called Proxy.AddURLProxy after Stop")
return
}
if len(path) == 0 || path[0] != '/' {
path = "/" + path
}
Expand All @@ -92,6 +98,9 @@ func (p *Proxy) Stop(timeoutSeconds int) {
log.Fatalf("Failed to shutdown gracefully: %v", err)
p.server.Close()
}
// Allow garbage collection in case the user keeps holding a reference to the Proxy.
p.proxyHandler = nil
p.server = nil
}

// RunProxy runs a local web proxy that listens on localAddress, and handles proxy requests by
Expand All @@ -101,10 +110,25 @@ func RunProxy(localAddress string, dialer *StreamDialer) (*Proxy, error) {
if err != nil {
return nil, fmt.Errorf("could not listen on address %v: %v", localAddress, err)
}
if dialer == nil {
return nil, errors.New("dialer must not be nil. Please create and pass a valid StreamDialer")
}

// The default http.Server doesn't close hijacked connections or cancel in-flight request contexts during
// shutdown. This can lead to lingering connections. We'll create a base context, propagated to requests,
// that is cancelled on shutdown. This enables handlers to gracefully terminate requests and close connections.
serverCtx, cancelCtx := context.WithCancelCause(context.Background())
proxyHandler := httpproxy.NewProxyHandler(dialer)
proxyHandler.FallbackHandler = http.NotFoundHandler()
server := &http.Server{Handler: proxyHandler}
server := &http.Server{
Handler: proxyHandler,
BaseContext: func(l net.Listener) context.Context {
return serverCtx
},
}
server.RegisterOnShutdown(func() {
cancelCtx(errors.New("server stopped"))
})
go server.Serve(listener)

host, portStr, err := net.SplitHostPort(listener.Addr().String())
Expand Down

0 comments on commit 49ebf7c

Please sign in to comment.