Skip to content

Commit

Permalink
Concurrent Websocket Handlers for Log Streaming (#7428)
Browse files Browse the repository at this point in the history
* concurrent conn with mutex

* add buffer size

* only localhost

* clarify caller needs to hold lock

* Update shared/logutil/stream.go

Co-authored-by: Preston Van Loon <preston@prysmaticlabs.com>

* Update shared/logutil/stream.go

Co-authored-by: Preston Van Loon <preston@prysmaticlabs.com>

* fix up tests

Co-authored-by: Preston Van Loon <preston@prysmaticlabs.com>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 6, 2020
1 parent c0ed43d commit 4d77978
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 44 deletions.
5 changes: 4 additions & 1 deletion shared/logutil/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ go_test(
name = "go_default_test",
srcs = ["stream_test.go"],
embed = [":go_default_library"],
deps = ["//shared/testutil/require:go_default_library"],
deps = [
"//shared/testutil/require:go_default_library",
"@com_github_sirupsen_logrus//hooks/test:go_default_library",
],
)
112 changes: 71 additions & 41 deletions shared/logutil/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package logutil
import (
"io"
"net/http"
"strings"
"sync"

"github.com/gorilla/websocket"
lru "github.com/hashicorp/golang-lru"
Expand All @@ -11,18 +13,36 @@ import (
log "github.com/sirupsen/logrus"
)

// LogCacheSize is the number of log entries to keep in memory for new
// websocket connections.
const LogCacheSize = 20
const (
// LogCacheSize is the number of log entries to keep in memory for new
// websocket connections.
LogCacheSize = 20
// Size for the buffered channel used for receiving log messages. The default
// value should be enough to handle most incoming amount of logs without
// blocking the thread.
logBufferSize = 100
)

// Compile time interface check.
var _ = io.Writer(&StreamServer{})
var (
// Compile time interface check.
_ = io.Writer(&StreamServer{})
streamUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Only allow requests from localhost.
return strings.Contains(r.Host, "localhost")
},
}
)

// StreamServer defines a a websocket server which can receive events from
// a feed and write them to open websocket connections.
type StreamServer struct {
feed *event.Feed
cache *lru.Cache
feed *event.Feed
cache *lru.Cache
clients map[*websocket.Conn]bool
lock sync.RWMutex
}

// NewLogStreamServer initializes a new stream server capable of
Expand All @@ -33,19 +53,15 @@ func NewLogStreamServer() *StreamServer {
panic(err) // This can only occur when the LogCacheSize is negative.
}
ss := &StreamServer{
feed: new(event.Feed),
cache: c,
feed: new(event.Feed),
cache: c,
clients: make(map[*websocket.Conn]bool),
}
addLogWriter(ss)
go ss.sendLogsToClients()
return ss
}

var streamUpgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
}

// Handler for new websocket connections to stream new events received
// via an event feed as they occur.
func (ss *StreamServer) Handler(w http.ResponseWriter, r *http.Request) {
Expand All @@ -54,52 +70,66 @@ func (ss *StreamServer) Handler(w http.ResponseWriter, r *http.Request) {
log.Errorf("Could not write websocket message: %v", err)
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("Could not close websocket connection: %v", err)
}
}()

ch := make(chan []byte, 1)
defer close(ch)
sub := ss.feed.Subscribe(ch)
defer sub.Unsubscribe()

// Backfill stream with recent messages.
for _, k := range ss.cache.Keys() {
d, ok := ss.cache.Get(k)
if ok {
if err := conn.WriteMessage(websocket.TextMessage, d.([]byte)); err != nil {
log.Errorf("Could not write websocket message: %v", err)
if err := conn.Close(); err != nil {
log.Errorf("Could not close websocket connection: %v", err)
}
return
}
}
}
ss.lock.Lock()
ss.clients[conn] = true
ss.lock.Unlock()
}

// Write a binary message and send over the event feed.
func (ss *StreamServer) Write(p []byte) (n int, err error) {
ss.feed.Send(p)
ss.cache.Add(rand.NewGenerator().Uint64(), p)
return len(p), nil
}

func (ss *StreamServer) sendLogsToClients() {
ch := make(chan []byte, logBufferSize)
defer close(ch)
sub := ss.feed.Subscribe(ch)
defer sub.Unsubscribe()

for {
select {
case evt := <-ch:
if err := conn.WriteMessage(websocket.TextMessage, evt); err != nil {
log.Errorf("Could not write websocket message: %v", err)
return
}
case <-r.Context().Done():
if err := conn.WriteMessage(websocket.CloseNormalClosure, []byte("context canceled")); err != nil {
log.Error(err)
return
ss.lock.Lock()
for conn := range ss.clients {
if err := conn.WriteMessage(websocket.TextMessage, evt); err != nil {
log.WithError(err).Error("Could not write websocket message")
ss.removeClient(conn)
}
}
ss.lock.Unlock()
case err := <-sub.Err():
if err := conn.WriteMessage(websocket.CloseInternalServerErr, []byte(err.Error())); err != nil {
log.Error(err)
return
ss.lock.Lock()
for conn := range ss.clients {
if err := conn.WriteMessage(websocket.CloseInternalServerErr, []byte(err.Error())); err != nil {
log.WithError(err).Error("Could not write websocket message")
}
ss.removeClient(conn)
}
ss.lock.Unlock()
}
}
}

// Write a binary message and send over the event feed.
func (ss *StreamServer) Write(p []byte) (n int, err error) {
ss.feed.Send(p)
ss.cache.Add(rand.NewGenerator().Uint64(), p)
return len(p), nil
// The caller of this function needs to acquire a mutex.
func (ss *StreamServer) removeClient(conn *websocket.Conn) {
delete(ss.clients, conn)
if err := conn.Close(); err != nil {
log.Errorf("Could not close websocket connection: %v", err)
}
}
32 changes: 30 additions & 2 deletions shared/logutil/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/prysmaticlabs/prysm/shared/testutil/require"
logTest "github.com/sirupsen/logrus/hooks/test"
)

type fakeAddr int
Expand Down Expand Up @@ -52,6 +53,32 @@ func (resp *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return fakeNetConn{strings.NewReader(""), resp.brw}, rw, nil
}

func TestLogStreamServer_DisallowsNonLocalhostOrigin(t *testing.T) {
hook := logTest.NewGlobal()
ss := NewLogStreamServer()
br := bufio.NewReader(strings.NewReader(""))
buf := new(bytes.Buffer)
bw := bufio.NewWriter(buf)
rw := httptest.NewRecorder()
resp := &testResponseWriter{
brw: bufio.NewReadWriter(br, bw),
ResponseWriter: rw,
}
req := &http.Request{
Method: "GET",
Host: "externalsource",
Header: http.Header{
"Upgrade": []string{"websocket"},
"Connection": []string{"upgrade"},
"Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
"Sec-Websocket-Version": []string{"13"},
},
}
ss.Handler(resp, req)
require.NoError(t, resp.brw.Flush())
require.LogsContain(t, hook, "origin not allowed")
}

func TestLogStreamServer_BackfillsMessages(t *testing.T) {
ss := NewLogStreamServer()
msgs := [][]byte{
Expand All @@ -74,6 +101,7 @@ func TestLogStreamServer_BackfillsMessages(t *testing.T) {
}
req := &http.Request{
Method: "GET",
Host: "localhost",
Header: http.Header{
"Upgrade": []string{"websocket"},
"Connection": []string{"upgrade"},
Expand All @@ -82,8 +110,8 @@ func TestLogStreamServer_BackfillsMessages(t *testing.T) {
},
}

go ss.Handler(resp, req)
time.Sleep(50 * time.Millisecond)
ss.Handler(resp, req)
go ss.sendLogsToClients()
require.NoError(t, resp.brw.Flush())
dst, err := ioutil.ReadAll(buf)
require.NoError(t, err)
Expand Down

0 comments on commit 4d77978

Please sign in to comment.