diff --git a/forward/fwd.go b/forward/fwd.go index ec4bea59..337d5eff 100644 --- a/forward/fwd.go +++ b/forward/fwd.go @@ -7,6 +7,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "net/http/httptest" "net/http/httputil" @@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter { } } +// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed +func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter { + return func(f *Forwarder) error { + f.httpForwarder.websocketConnectionClosedHook = hook + return nil + } +} + // ResponseModifier defines a response modifier for the HTTP forwarder func ResponseModifier(responseModifier func(*http.Response) error) optSetter { return func(f *Forwarder) error { @@ -188,7 +197,8 @@ type httpForwarder struct { log OxyLogger - bufferPool httputil.BufferPool + bufferPool httputil.BufferPool + websocketConnectionClosedHook func(req *http.Request, conn net.Conn) } const defaultFlushInterval = time.Duration(100) * time.Millisecond @@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err) return } - defer underlyingConn.Close() - defer targetConn.Close() + defer func() { + underlyingConn.Close() + targetConn.Close() + if f.websocketConnectionClosedHook != nil { + f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn()) + } + }() errClient := make(chan error, 1) errBackend := make(chan error, 1) diff --git a/forward/fwd_websocket_test.go b/forward/fwd_websocket_test.go index 4700df76..91f47d52 100644 --- a/forward/fwd_websocket_test.go +++ b/forward/fwd_websocket_test.go @@ -57,6 +57,54 @@ func TestWebSocketTCPClose(t *testing.T) { assert.Equal(t, 1006, wsErr.Code) } +func TestWebsocketConnectionClosedHook(t *testing.T) { + closed := make(chan struct{}) + f, err := New(WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) { + close(closed) + })) + require.NoError(t, err) + + mux := http.NewServeMux() + mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) { + msg := make([]byte, 4) + conn.Read(msg) + conn.Write(msg) + conn.Close() + })) + + srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + mux.ServeHTTP(w, req) + }) + defer srv.Close() + + proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) { + req.URL = testutils.ParseURI(srv.URL) + f.ServeHTTP(w, req) + }) + defer proxy.Close() + + serverAddr := proxy.Listener.Addr().String() + + headers := http.Header{} + webSocketURL := "ws://" + serverAddr + "/ws" + headers.Add("Origin", webSocketURL) + + conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers) + require.NoError(t, err, "Error during Dial with response: %+v", resp) + + conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK")) + fmt.Println(conn.ReadMessage()) + + conn.Close() + + select { + case <-time.After(time.Second): + t.Errorf("Websocket Hook not called") + case <-closed: + + } +} + func TestWebSocketEcho(t *testing.T) { f, err := New() require.NoError(t, err)