Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 45 additions & 22 deletions websocketproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,13 @@ var (

// DefaultDialer is a dialer with all fields set to the default zero values.
DefaultDialer = websocket.DefaultDialer

// DefaultReplicator is a simple message passthrough
DefaultReplicator = passthroughReplicator
)

type MessageReplicatorFunc func(dst, src *websocket.Conn, errc chan error)

// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
type WebsocketProxy struct {
Expand All @@ -33,6 +38,14 @@ type WebsocketProxy struct {
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)

// IncomeReplicator is a function that forward messages incoming from origin
// into the backend. If nil, passthroughsReplicator is used.
IncomeReplicator MessageReplicatorFunc

// BackendReplicator is a function that forwards messages from backend into
// origin. If nil, passthroughsReplicator is used.
BackendReplicator MessageReplicatorFunc

// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Expand Down Expand Up @@ -177,30 +190,18 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

errClient := make(chan error, 1)
errBackend := make(chan error, 1)
replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}

incomingReplicator := w.IncomeReplicator
if w.IncomeReplicator == nil {
incomingReplicator = DefaultReplicator
}
go incomingReplicator(connPub, connBackend, errClient)

go replicateWebsocketConn(connPub, connBackend, errClient)
go replicateWebsocketConn(connBackend, connPub, errBackend)
backendReplicator := w.BackendReplicator
if w.BackendReplicator == nil {
backendReplicator = DefaultReplicator
}
go backendReplicator(connBackend, connPub, errBackend)

var message string
select {
Expand Down Expand Up @@ -231,3 +232,25 @@ func copyResponse(rw http.ResponseWriter, resp *http.Response) error {
_, err := io.Copy(rw, resp.Body)
return err
}

func passthroughReplicator(dst, src *websocket.Conn, errc chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}
}