diff --git a/websocketproxy.go b/websocketproxy.go index 63d39ba..96ca283 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -23,8 +23,13 @@ var ( // DefaultDialer is a dialer with all fields set to the default zero values. DefaultDialer = websocket.DefaultDialer + + // DefaultReplicator is a simple message passthrough + DefaultReplicator = passthroughReplicator ) +type MessageReplicatorFunc func(dst, src *websocket.Conn, errc chan error) + // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket // connection and proxies it to another server. type WebsocketProxy struct { @@ -33,6 +38,14 @@ type WebsocketProxy struct { // which will be forwarded to another server. Director func(incoming *http.Request, out http.Header) + // IncomeReplicator is a function that forward messages incoming from origin + // into the backend. If nil, passthroughsReplicator is used. + IncomeReplicator MessageReplicatorFunc + + // BackendReplicator is a function that forwards messages from backend into + // origin. If nil, passthroughsReplicator is used. + BackendReplicator MessageReplicatorFunc + // Backend returns the backend URL which the proxy uses to reverse proxy // the incoming WebSocket connection. Request is the initial incoming and // unmodified request. @@ -177,30 +190,18 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) - replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { - for { - msgType, msg, err := src.ReadMessage() - if err != nil { - m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) - if e, ok := err.(*websocket.CloseError); ok { - if e.Code != websocket.CloseNoStatusReceived { - m = websocket.FormatCloseMessage(e.Code, e.Text) - } - } - errc <- err - dst.WriteMessage(websocket.CloseMessage, m) - break - } - err = dst.WriteMessage(msgType, msg) - if err != nil { - errc <- err - break - } - } + + incomingReplicator := w.IncomeReplicator + if w.IncomeReplicator == nil { + incomingReplicator = DefaultReplicator } + go incomingReplicator(connPub, connBackend, errClient) - go replicateWebsocketConn(connPub, connBackend, errClient) - go replicateWebsocketConn(connBackend, connPub, errBackend) + backendReplicator := w.BackendReplicator + if w.BackendReplicator == nil { + backendReplicator = DefaultReplicator + } + go backendReplicator(connBackend, connPub, errBackend) var message string select { @@ -231,3 +232,25 @@ func copyResponse(rw http.ResponseWriter, resp *http.Response) error { _, err := io.Copy(rw, resp.Body) return err } + +func passthroughReplicator(dst, src *websocket.Conn, errc chan error) { + for { + msgType, msg, err := src.ReadMessage() + if err != nil { + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) + if e, ok := err.(*websocket.CloseError); ok { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } + } + errc <- err + dst.WriteMessage(websocket.CloseMessage, m) + break + } + err = dst.WriteMessage(msgType, msg) + if err != nil { + errc <- err + break + } + } +}