Skip to content

Commit

Permalink
Add a hook when websocket connection is closed (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliens authored and emilevauge committed Jul 17, 2018
1 parent c81cf8f commit a3ed5f6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 3 deletions.
21 changes: 18 additions & 3 deletions forward/fwd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
Expand Down Expand Up @@ -126,6 +127,14 @@ func StateListener(stateListener UrlForwardingStateListener) optSetter {
}
}

// WebsocketConnectionClosedHook defines a hook called when websocket connection is closed
func WebsocketConnectionClosedHook(hook func(req *http.Request, conn net.Conn)) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.websocketConnectionClosedHook = hook
return nil
}
}

// ResponseModifier defines a response modifier for the HTTP forwarder
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error {
Expand Down Expand Up @@ -188,7 +197,8 @@ type httpForwarder struct {

log OxyLogger

bufferPool httputil.BufferPool
bufferPool httputil.BufferPool
websocketConnectionClosedHook func(req *http.Request, conn net.Conn)
}

const defaultFlushInterval = time.Duration(100) * time.Millisecond
Expand Down Expand Up @@ -374,8 +384,13 @@ func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request,
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return
}
defer underlyingConn.Close()
defer targetConn.Close()
defer func() {
underlyingConn.Close()
targetConn.Close()
if f.websocketConnectionClosedHook != nil {
f.websocketConnectionClosedHook(req, underlyingConn.UnderlyingConn())
}
}()

errClient := make(chan error, 1)
errBackend := make(chan error, 1)
Expand Down
48 changes: 48 additions & 0 deletions forward/fwd_websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,54 @@ func TestWebSocketTCPClose(t *testing.T) {
assert.Equal(t, 1006, wsErr.Code)
}

func TestWebsocketConnectionClosedHook(t *testing.T) {
closed := make(chan struct{})
f, err := New(WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
close(closed)
}))
require.NoError(t, err)

mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
conn.Read(msg)
conn.Write(msg)
conn.Close()
}))

srv := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
})
defer srv.Close()

proxy := testutils.NewHandler(func(w http.ResponseWriter, req *http.Request) {
req.URL = testutils.ParseURI(srv.URL)
f.ServeHTTP(w, req)
})
defer proxy.Close()

serverAddr := proxy.Listener.Addr().String()

headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)

conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)

conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
fmt.Println(conn.ReadMessage())

conn.Close()

select {
case <-time.After(time.Second):
t.Errorf("Websocket Hook not called")
case <-closed:

}
}

func TestWebSocketEcho(t *testing.T) {
f, err := New()
require.NoError(t, err)
Expand Down

0 comments on commit a3ed5f6

Please sign in to comment.