diff --git a/authfe/proxy.go b/authfe/proxy.go index bece95236..4d54f6eda 100644 --- a/authfe/proxy.go +++ b/authfe/proxy.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "io" "net" @@ -31,11 +32,13 @@ func newProxy(cfg proxyConfig) (http.Handler, error) { case "https": fallthrough case "http": + // Make all transformations outside of the director since // they are also required when proxying websockets return &httpProxy{cfg, httputil.ReverseProxy{ - Director: func(*http.Request) {}, - Transport: proxyTransport, + Director: func(*http.Request) {}, + Transport: proxyTransport, + ErrorHandler: handleError, }}, nil case "mock": return &mockProxy{cfg}, nil @@ -163,6 +166,18 @@ func (p *httpProxy) proxyWS(w http.ResponseWriter, r *http.Request) { logger.Debugf("proxy: websocket: connection closed") } +func handleError(rw http.ResponseWriter, outReq *http.Request, err error) { + ctx := outReq.Context() + var header int + if ctx.Err() == context.Canceled { + header = 499 // Client Closed Request (nginx convention) + } else { + user.LogWith(ctx, logging.Global()).WithField("err", err).Errorln("http proxy error") + header = http.StatusBadGateway + } + rw.WriteHeader(header) +} + type closeWriter interface { CloseWrite() error } diff --git a/authfe/proxy_test.go b/authfe/proxy_test.go index b755ff0f1..a49461384 100644 --- a/authfe/proxy_test.go +++ b/authfe/proxy_test.go @@ -216,3 +216,49 @@ func TestProxyGRPCTracing(t *testing.T) { assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, "world", string(body[:c])) } + +func TestProxyClientClosed(t *testing.T) { + serverCh := make(chan interface{}) + + // Set up a slow server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(418) + serverCh <- nil // Synchonise with the test + })) + defer server.Close() + + // Set up a proxy handler pointing at the server + serverURL, err := url.Parse(server.URL) + assert.NoError(t, err, "Cannot parse URL") + proxyHandler, _ := newProxy(proxyConfig{hostAndPort: serverURL.Host, protocol: "http"}) + + // Intercept the proxy response to check the response code + codeCh := make(chan int) + interceptedProxyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recorder := httptest.NewRecorder() + proxyHandler.ServeHTTP(recorder, r) + codeCh <- recorder.Code + w.WriteHeader(recorder.Code) + }) + + // Set up the proxy server + proxyServer := httptest.NewServer(interceptedProxyHandler) + defer proxyServer.Close() + + // Make a request which times out faster than the slow server + req, err := http.NewRequest("GET", proxyServer.URL, nil) + require.NoError(t, err) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + req = req.WithContext(ctx) + + _, err = http.DefaultClient.Do(req) + // Check that the request timed out early + assert.Error(t, context.DeadlineExceeded, err) + // Wait for the slow server to receive the request, and allow it to continue + assert.Nil(t, <-serverCh) + // Wait for the proxyHandler to finish handling the request, and allow it to continue + responseCode := <-codeCh + // Check that the proxy server set the response code to 499 + assert.Equal(t, 499, responseCode) +}