From 52cf3a155dcce1316505761cc876b76b7b38a830 Mon Sep 17 00:00:00 2001 From: kasey <489222+kasey@users.noreply.github.com> Date: Thu, 24 Oct 2024 14:16:17 -0500 Subject: [PATCH] Safe StreamEvents write loop (#14557) * new type for tests where errors are only logged * StreamHandler waits for write loop exit * add test case for writer timeout * add changelog * add missing file * logging fix * fix logging test to allow info logs * naming/comments; make response controller private * simplify cancel defers * fix typo in test file name --------- Co-authored-by: Kasey Kirkham --- CHANGELOG.md | 1 + beacon-chain/rpc/eth/events/BUILD.bazel | 2 + beacon-chain/rpc/eth/events/events.go | 148 +++++++++++---- beacon-chain/rpc/eth/events/events_test.go | 202 +++++++++++++++------ beacon-chain/rpc/eth/events/http_test.go | 48 ++++- beacon-chain/rpc/eth/events/log.go | 6 + beacon-chain/rpc/eth/events/server.go | 1 + testing/util/BUILD.bazel | 5 + testing/util/logging.go | 90 +++++++++ testing/util/logging_test.go | 78 ++++++++ 10 files changed, 486 insertions(+), 95 deletions(-) create mode 100644 beacon-chain/rpc/eth/events/log.go create mode 100644 testing/util/logging.go create mode 100644 testing/util/logging_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 48e15ab28b0a..58f96e2ece12 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ The format is based on Keep a Changelog, and this project adheres to Semantic Ve - Fixed mesh size by appending `gParams.Dhi = gossipSubDhi` - Fix skipping partial withdrawals count. +- wait for the async StreamEvent writer to exit before leaving the http handler, avoiding race condition panics [pr](https://github.com/prysmaticlabs/prysm/pull/14557) ### Security diff --git a/beacon-chain/rpc/eth/events/BUILD.bazel b/beacon-chain/rpc/eth/events/BUILD.bazel index caa311b6b515..ad6b09eb30bd 100644 --- a/beacon-chain/rpc/eth/events/BUILD.bazel +++ b/beacon-chain/rpc/eth/events/BUILD.bazel @@ -4,6 +4,7 @@ go_library( name = "go_default_library", srcs = [ "events.go", + "log.go", "server.go", ], importpath = "github.com/prysmaticlabs/prysm/v5/beacon-chain/rpc/eth/events", @@ -58,5 +59,6 @@ go_test( "//testing/util:go_default_library", "@com_github_ethereum_go_ethereum//common:go_default_library", "@com_github_r3labs_sse_v2//:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", ], ) diff --git a/beacon-chain/rpc/eth/events/events.go b/beacon-chain/rpc/eth/events/events.go index 7698d2fb9017..92071fda9d2a 100644 --- a/beacon-chain/rpc/eth/events/events.go +++ b/beacon-chain/rpc/eth/events/events.go @@ -28,7 +28,6 @@ import ( eth "github.com/prysmaticlabs/prysm/v5/proto/prysm/v1alpha1" "github.com/prysmaticlabs/prysm/v5/runtime/version" "github.com/prysmaticlabs/prysm/v5/time/slots" - log "github.com/sirupsen/logrus" ) const DefaultEventFeedDepth = 1000 @@ -74,13 +73,6 @@ var ( errWriterUnusable = errors.New("http response writer is unusable") ) -// StreamingResponseWriter defines a type that can be used by the eventStreamer. -// This must be an http.ResponseWriter that supports flushing and hijacking. -type StreamingResponseWriter interface { - http.ResponseWriter - http.Flusher -} - // The eventStreamer uses lazyReaders to defer serialization until the moment the value is ready to be written to the client. 
type lazyReader func() io.Reader @@ -150,6 +142,7 @@ func newTopicRequest(topics []string) (*topicRequest, error) { // Servers may send SSE comments beginning with ':' for any purpose, // including to keep the event stream connection alive in the presence of proxy servers. func (s *Server) StreamEvents(w http.ResponseWriter, r *http.Request) { + log.Debug("Starting StreamEvents handler") ctx, span := trace.StartSpan(r.Context(), "events.StreamEvents") defer span.End() @@ -159,47 +152,51 @@ func (s *Server) StreamEvents(w http.ResponseWriter, r *http.Request) { return } - sw, ok := w.(StreamingResponseWriter) - if !ok { - msg := "beacon node misconfiguration: http stack may not support required response handling features, like flushing" - httputil.HandleError(w, msg, http.StatusInternalServerError) - return + timeout := s.EventWriteTimeout + if timeout == 0 { + timeout = time.Duration(params.BeaconConfig().SecondsPerSlot) * time.Second } - depth := s.EventFeedDepth - if depth == 0 { - depth = DefaultEventFeedDepth + ka := s.KeepAliveInterval + if ka == 0 { + ka = timeout } - es, err := newEventStreamer(depth, s.KeepAliveInterval) - if err != nil { - httputil.HandleError(w, err.Error(), http.StatusInternalServerError) - return + buffSize := s.EventFeedDepth + if buffSize == 0 { + buffSize = DefaultEventFeedDepth } + api.SetSSEHeaders(w) + sw := newStreamingResponseController(w, timeout) ctx, cancel := context.WithCancel(ctx) defer cancel() - api.SetSSEHeaders(sw) + es := newEventStreamer(buffSize, ka) + go es.outboxWriteLoop(ctx, cancel, sw) if err := es.recvEventLoop(ctx, cancel, topics, s); err != nil { log.WithError(err).Debug("Shutting down StreamEvents handler.") } + cleanupStart := time.Now() + es.waitForExit() + log.WithField("cleanup_wait", time.Since(cleanupStart)).Debug("streamEvents shutdown complete") } -func newEventStreamer(buffSize int, ka time.Duration) (*eventStreamer, error) { - if ka == 0 { - ka = time.Duration(params.BeaconConfig().SecondsPerSlot) * time.Second - } +func newEventStreamer(buffSize int, ka time.Duration) *eventStreamer { return &eventStreamer{ - outbox: make(chan lazyReader, buffSize), - keepAlive: ka, - }, nil + outbox: make(chan lazyReader, buffSize), + keepAlive: ka, + openUntilExit: make(chan struct{}), + } } type eventStreamer struct { - outbox chan lazyReader - keepAlive time.Duration + outbox chan lazyReader + keepAlive time.Duration + openUntilExit chan struct{} } func (es *eventStreamer) recvEventLoop(ctx context.Context, cancel context.CancelFunc, req *topicRequest, s *Server) error { + defer close(es.outbox) + defer cancel() eventsChan := make(chan *feed.Event, len(es.outbox)) if req.needOpsFeed { opsSub := s.OperationNotifier.OperationFeed().Subscribe(eventsChan) @@ -228,7 +225,6 @@ func (es *eventStreamer) recvEventLoop(ctx context.Context, cancel context.Cance // channel should stay relatively empty, which gives this loop time to unsubscribe // and cleanup before the event stream channel fills and disrupts other readers. if err := es.safeWrite(ctx, lr); err != nil { - cancel() // note: we could hijack the connection and close it here. Does that cause issues? What are the benefits? // A benefit of hijack and close is that it may force an error on the remote end, however just closing the context of the // http handler may be sufficient to cause the remote http response reader to close. @@ -265,12 +261,13 @@ func newlineReader() io.Reader { // outboxWriteLoop runs in a separate goroutine. 
Its job is to write the values in the outbox to
 // the client as fast as the client can read them.
-func (es *eventStreamer) outboxWriteLoop(ctx context.Context, cancel context.CancelFunc, w StreamingResponseWriter) {
+func (es *eventStreamer) outboxWriteLoop(ctx context.Context, cancel context.CancelFunc, w *streamingResponseWriterController) {
 	var err error
 	defer func() {
 		if err != nil {
 			log.WithError(err).Debug("Event streamer shutting down due to error.")
 		}
+		es.exit()
 	}()
 	defer cancel()
 	// Write a keepalive at the start to test the connection and simplify test setup.
@@ -310,18 +307,43 @@ func (es *eventStreamer) outboxWriteLoop(ctx context.Context, cancel context.Can
 	}
 }
 
-func writeLazyReaderWithRecover(w StreamingResponseWriter, lr lazyReader) (err error) {
+func (es *eventStreamer) exit() {
+	drained := 0
+	for range es.outbox {
+		drained += 1
+	}
+	log.WithField("undelivered_events", drained).Debug("Event stream outbox drained.")
+	close(es.openUntilExit)
+}
+
+// waitForExit blocks until the outboxWriteLoop has exited.
+// While this function blocks, it is not yet safe to exit the http handler,
+// because the outboxWriteLoop may still be writing to the http ResponseWriter.
+func (es *eventStreamer) waitForExit() {
+	<-es.openUntilExit
+}
+
+func writeLazyReaderWithRecover(w *streamingResponseWriterController, lr lazyReader) (err error) {
 	defer func() {
 		if r := recover(); r != nil {
 			log.WithField("panic", r).Error("Recovered from panic while writing event to client.")
 			err = errWriterUnusable
 		}
 	}()
-	_, err = io.Copy(w, lr())
+	r := lr()
+	out, err := io.ReadAll(r)
+	if err != nil {
+		return err
+	}
+	_, err = w.Write(out)
 	return err
 }
 
-func (es *eventStreamer) writeOutbox(ctx context.Context, w StreamingResponseWriter, first lazyReader) error {
+func (es *eventStreamer) writeOutbox(ctx context.Context, w *streamingResponseWriterController, first lazyReader) error {
+	// The outboxWriteLoop is responsible for managing the keep-alive timer and toggling between reading from the outbox
+	// when it is ready, only allowing the keep-alive to fire when there hasn't been a write in the keep-alive interval.
+	// Since outboxWriteLoop will get either the first event or the keep-alive, we let it pass in the first event to write,
+	// either the event's lazyReader, or nil for a keep-alive.
 	needKeepAlive := true
 	if first != nil {
 		if err := writeLazyReaderWithRecover(w, first); err != nil {
@@ -337,6 +359,11 @@ func (es *eventStreamer) writeOutbox(ctx context.Context, w StreamingResponseWri
 		case <-ctx.Done():
 			return ctx.Err()
 		case rf := <-es.outbox:
+			// We don't want to call Flush until we've exhausted all the writes - it's always preferable to
+			// just keep draining the outbox and rely on the underlying Write code to flush+block when it
+			// needs to based on buffering. Whenever we fill the buffer with a string of writes, the underlying
+			// code will flush on its own, so it's better to explicitly flush only once, after we've totally
+			// drained the outbox, to catch any dangling bytes stuck in a buffer.
if err := writeLazyReaderWithRecover(w, rf); err != nil { return err } @@ -347,8 +374,7 @@ func (es *eventStreamer) writeOutbox(ctx context.Context, w StreamingResponseWri return err } } - w.Flush() - return nil + return w.Flush() } } } @@ -638,3 +664,51 @@ func (s *Server) currentPayloadAttributes(ctx context.Context) (lazyReader, erro }) }, nil } + +func newStreamingResponseController(rw http.ResponseWriter, timeout time.Duration) *streamingResponseWriterController { + rc := http.NewResponseController(rw) + return &streamingResponseWriterController{ + timeout: timeout, + rw: rw, + rc: rc, + } +} + +// streamingResponseWriterController provides an interface similar to an http.ResponseWriter, +// wrapping an http.ResponseWriter and an http.ResponseController, using the ResponseController +// to set and clear deadlines for Write and Flush methods, and delegating to the underlying +// types to Write and Flush. +type streamingResponseWriterController struct { + timeout time.Duration + rw http.ResponseWriter + rc *http.ResponseController +} + +func (c *streamingResponseWriterController) Write(b []byte) (int, error) { + if err := c.setDeadline(); err != nil { + return 0, err + } + out, err := c.rw.Write(b) + if err != nil { + return out, err + } + return out, c.clearDeadline() +} + +func (c *streamingResponseWriterController) setDeadline() error { + return c.rc.SetWriteDeadline(time.Now().Add(c.timeout)) +} + +func (c *streamingResponseWriterController) clearDeadline() error { + return c.rc.SetWriteDeadline(time.Time{}) +} + +func (c *streamingResponseWriterController) Flush() error { + if err := c.setDeadline(); err != nil { + return err + } + if err := c.rc.Flush(); err != nil { + return err + } + return c.clearDeadline() +} diff --git a/beacon-chain/rpc/eth/events/events_test.go b/beacon-chain/rpc/eth/events/events_test.go index 420f34fdf96d..32daf1c7218f 100644 --- a/beacon-chain/rpc/eth/events/events_test.go +++ b/beacon-chain/rpc/eth/events/events_test.go @@ -27,9 +27,12 @@ import ( "github.com/prysmaticlabs/prysm/v5/testing/require" "github.com/prysmaticlabs/prysm/v5/testing/util" sse "github.com/r3labs/sse/v2" + "github.com/sirupsen/logrus" ) -func requireAllEventsReceived(t *testing.T, stn, opn *mockChain.EventFeedWrapper, events []*feed.Event, req *topicRequest, s *Server, w *StreamingResponseWriterRecorder) { +var testEventWriteTimeout = 100 * time.Millisecond + +func requireAllEventsReceived(t *testing.T, stn, opn *mockChain.EventFeedWrapper, events []*feed.Event, req *topicRequest, s *Server, w *StreamingResponseWriterRecorder, logs chan *logrus.Entry) { // maxBufferSize param copied from sse lib client code sseR := sse.NewEventStreamReader(w.Body(), 1<<24) ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -77,21 +80,29 @@ func requireAllEventsReceived(t *testing.T, stn, opn *mockChain.EventFeedWrapper } } }() - select { - case <-done: - break - case <-ctx.Done(): - t.Fatalf("context canceled / timed out waiting for events, err=%v", ctx.Err()) + for { + select { + case entry := <-logs: + errAttr, ok := entry.Data[logrus.ErrorKey] + if ok { + t.Errorf("unexpected error in logs: %v", errAttr) + } + case <-done: + require.Equal(t, 0, len(expected), "expected events not seen") + return + case <-ctx.Done(): + t.Fatalf("context canceled / timed out waiting for events, err=%v", ctx.Err()) + } } - require.Equal(t, 0, len(expected), "expected events not seen") } -func (tr *topicRequest) testHttpRequest(_ *testing.T) *http.Request { +func (tr *topicRequest) 
testHttpRequest(ctx context.Context, _ *testing.T) *http.Request { tq := make([]string, 0, len(tr.topics)) for topic := range tr.topics { tq = append(tq, "topics="+topic) } - return httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://example.com/eth/v1/events?%s", strings.Join(tq, "&")), nil) + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://example.com/eth/v1/events?%s", strings.Join(tq, "&")), nil) + return req.WithContext(ctx) } func operationEventsFixtures(t *testing.T) (*topicRequest, []*feed.Event) { @@ -235,31 +246,77 @@ func operationEventsFixtures(t *testing.T) (*topicRequest, []*feed.Event) { } } +type streamTestSync struct { + done chan struct{} + cancel func() + undo func() + logs chan *logrus.Entry + ctx context.Context + t *testing.T +} + +func (s *streamTestSync) cleanup() { + s.cancel() + select { + case <-s.done: + case <-time.After(10 * time.Millisecond): + s.t.Fatal("timed out waiting for handler to finish") + } + s.undo() +} + +func (s *streamTestSync) markDone() { + close(s.done) +} + +func newStreamTestSync(t *testing.T) *streamTestSync { + logChan := make(chan *logrus.Entry, 100) + cew := util.NewChannelEntryWriter(logChan) + undo := util.RegisterHookWithUndo(logger, cew) + ctx, cancel := context.WithCancel(context.Background()) + return &streamTestSync{ + t: t, + ctx: ctx, + cancel: cancel, + logs: logChan, + undo: undo, + done: make(chan struct{}), + } +} + func TestStreamEvents_OperationsEvents(t *testing.T) { t.Run("operations", func(t *testing.T) { + testSync := newStreamTestSync(t) + defer testSync.cleanup() stn := mockChain.NewEventFeedWrapper() opn := mockChain.NewEventFeedWrapper() s := &Server{ StateNotifier: &mockChain.SimpleNotifier{Feed: stn}, OperationNotifier: &mockChain.SimpleNotifier{Feed: opn}, + EventWriteTimeout: testEventWriteTimeout, } topics, events := operationEventsFixtures(t) - request := topics.testHttpRequest(t) - w := NewStreamingResponseWriterRecorder() + request := topics.testHttpRequest(testSync.ctx, t) + w := NewStreamingResponseWriterRecorder(testSync.ctx) go func() { s.StreamEvents(w, request) + testSync.markDone() }() - requireAllEventsReceived(t, stn, opn, events, topics, s, w) + requireAllEventsReceived(t, stn, opn, events, topics, s, w, testSync.logs) }) t.Run("state", func(t *testing.T) { + testSync := newStreamTestSync(t) + defer testSync.cleanup() + stn := mockChain.NewEventFeedWrapper() opn := mockChain.NewEventFeedWrapper() s := &Server{ StateNotifier: &mockChain.SimpleNotifier{Feed: stn}, OperationNotifier: &mockChain.SimpleNotifier{Feed: opn}, + EventWriteTimeout: testEventWriteTimeout, } topics, err := newTopicRequest([]string{ @@ -269,8 +326,8 @@ func TestStreamEvents_OperationsEvents(t *testing.T) { BlockTopic, }) require.NoError(t, err) - request := topics.testHttpRequest(t) - w := NewStreamingResponseWriterRecorder() + request := topics.testHttpRequest(testSync.ctx, t) + w := NewStreamingResponseWriterRecorder(testSync.ctx) b, err := blocks.NewSignedBeaconBlock(util.HydrateSignedBeaconBlock(ð.SignedBeaconBlock{})) require.NoError(t, err) @@ -323,9 +380,10 @@ func TestStreamEvents_OperationsEvents(t *testing.T) { go func() { s.StreamEvents(w, request) + testSync.markDone() }() - requireAllEventsReceived(t, stn, opn, events, topics, s, w) + requireAllEventsReceived(t, stn, opn, events, topics, s, w, testSync.logs) }) t.Run("payload attributes", func(t *testing.T) { type testCase struct { @@ -396,59 +454,93 @@ func TestStreamEvents_OperationsEvents(t *testing.T) { }, } for _, tc := range testCases { - 
st := tc.getState() - v := ð.Validator{ExitEpoch: math.MaxUint64} - require.NoError(t, st.SetValidators([]*eth.Validator{v})) - currentSlot := primitives.Slot(0) - // to avoid slot processing - require.NoError(t, st.SetSlot(currentSlot+1)) - b := tc.getBlock() - mockChainService := &mockChain.ChainService{ - Root: make([]byte, 32), - State: st, - Block: b, - Slot: ¤tSlot, - } + t.Run(tc.name, func(t *testing.T) { + testSync := newStreamTestSync(t) + defer testSync.cleanup() - stn := mockChain.NewEventFeedWrapper() - opn := mockChain.NewEventFeedWrapper() - s := &Server{ - StateNotifier: &mockChain.SimpleNotifier{Feed: stn}, - OperationNotifier: &mockChain.SimpleNotifier{Feed: opn}, - HeadFetcher: mockChainService, - ChainInfoFetcher: mockChainService, - TrackedValidatorsCache: cache.NewTrackedValidatorsCache(), - } - if tc.SetTrackedValidatorsCache != nil { - tc.SetTrackedValidatorsCache(s.TrackedValidatorsCache) - } - topics, err := newTopicRequest([]string{PayloadAttributesTopic}) - require.NoError(t, err) - request := topics.testHttpRequest(t) - w := NewStreamingResponseWriterRecorder() - events := []*feed.Event{&feed.Event{Type: statefeed.MissedSlot}} - - go func() { - s.StreamEvents(w, request) - }() - requireAllEventsReceived(t, stn, opn, events, topics, s, w) + st := tc.getState() + v := ð.Validator{ExitEpoch: math.MaxUint64} + require.NoError(t, st.SetValidators([]*eth.Validator{v})) + currentSlot := primitives.Slot(0) + // to avoid slot processing + require.NoError(t, st.SetSlot(currentSlot+1)) + b := tc.getBlock() + mockChainService := &mockChain.ChainService{ + Root: make([]byte, 32), + State: st, + Block: b, + Slot: ¤tSlot, + } + + stn := mockChain.NewEventFeedWrapper() + opn := mockChain.NewEventFeedWrapper() + s := &Server{ + StateNotifier: &mockChain.SimpleNotifier{Feed: stn}, + OperationNotifier: &mockChain.SimpleNotifier{Feed: opn}, + HeadFetcher: mockChainService, + ChainInfoFetcher: mockChainService, + TrackedValidatorsCache: cache.NewTrackedValidatorsCache(), + EventWriteTimeout: testEventWriteTimeout, + } + if tc.SetTrackedValidatorsCache != nil { + tc.SetTrackedValidatorsCache(s.TrackedValidatorsCache) + } + topics, err := newTopicRequest([]string{PayloadAttributesTopic}) + require.NoError(t, err) + request := topics.testHttpRequest(testSync.ctx, t) + w := NewStreamingResponseWriterRecorder(testSync.ctx) + events := []*feed.Event{&feed.Event{Type: statefeed.MissedSlot}} + + go func() { + s.StreamEvents(w, request) + testSync.markDone() + }() + requireAllEventsReceived(t, stn, opn, events, topics, s, w, testSync.logs) + }) } }) } -func TestStuckReader(t *testing.T) { +func TestStuckReaderScenarios(t *testing.T) { + cases := []struct { + name string + queueDepth func([]*feed.Event) int + }{ + { + name: "slow reader - queue overflows", + queueDepth: func(events []*feed.Event) int { + return len(events) - 1 + }, + }, + { + name: "slow reader - all queued, but writer is stuck, write timeout", + queueDepth: func(events []*feed.Event) int { + return len(events) + 1 + }, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + wedgedWriterTestCase(t, c.queueDepth) + }) + } +} + +func wedgedWriterTestCase(t *testing.T, queueDepth func([]*feed.Event) int) { topics, events := operationEventsFixtures(t) require.Equal(t, 8, len(events)) + // set eventFeedDepth to a number lower than the events we intend to send to force the server to drop the reader. 
stn := mockChain.NewEventFeedWrapper() opn := mockChain.NewEventFeedWrapper() s := &Server{ + EventWriteTimeout: 10 * time.Millisecond, StateNotifier: &mockChain.SimpleNotifier{Feed: stn}, OperationNotifier: &mockChain.SimpleNotifier{Feed: opn}, - EventFeedDepth: len(events) - 1, + EventFeedDepth: queueDepth(events), } - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() eventsWritten := make(chan struct{}) go func() { @@ -468,8 +560,8 @@ func TestStuckReader(t *testing.T) { close(eventsWritten) }() - request := topics.testHttpRequest(t) - w := NewStreamingResponseWriterRecorder() + request := topics.testHttpRequest(ctx, t) + w := NewStreamingResponseWriterRecorder(ctx) handlerFinished := make(chan struct{}) go func() { diff --git a/beacon-chain/rpc/eth/events/http_test.go b/beacon-chain/rpc/eth/events/http_test.go index 1bfaaa873da3..afff92518a32 100644 --- a/beacon-chain/rpc/eth/events/http_test.go +++ b/beacon-chain/rpc/eth/events/http_test.go @@ -1,10 +1,12 @@ package events import ( + "context" "io" "net/http" "net/http/httptest" "testing" + "time" "github.com/prysmaticlabs/prysm/v5/testing/require" ) @@ -17,32 +19,66 @@ type StreamingResponseWriterRecorder struct { status chan int bodyRecording []byte flushed bool + writeDeadline time.Time + ctx context.Context } func (w *StreamingResponseWriterRecorder) StatusChan() chan int { return w.status } -func NewStreamingResponseWriterRecorder() *StreamingResponseWriterRecorder { +func NewStreamingResponseWriterRecorder(ctx context.Context) *StreamingResponseWriterRecorder { r, w := io.Pipe() return &StreamingResponseWriterRecorder{ ResponseWriter: httptest.NewRecorder(), r: r, w: w, status: make(chan int, 1), + ctx: ctx, } } // Write implements http.ResponseWriter. func (w *StreamingResponseWriterRecorder) Write(data []byte) (int, error) { w.WriteHeader(http.StatusOK) - n, err := w.w.Write(data) + written, err := writeWithDeadline(w.ctx, w.w, data, w.writeDeadline) if err != nil { - return n, err + return written, err } + // The test response writer is non-blocking. return w.ResponseWriter.Write(data) } +var zeroTimeValue = time.Time{} + +func writeWithDeadline(ctx context.Context, w io.Writer, data []byte, deadline time.Time) (int, error) { + result := struct { + written int + err error + }{} + done := make(chan struct{}) + go func() { + defer close(done) + result.written, result.err = w.Write(data) + }() + if deadline == zeroTimeValue { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-done: + return result.written, result.err + } + } + select { + case <-time.After(time.Until(deadline)): + return 0, http.ErrHandlerTimeout + case <-done: + return result.written, result.err + case <-ctx.Done(): + return 0, ctx.Err() + } +} + // WriteHeader implements http.ResponseWriter. 
func (w *StreamingResponseWriterRecorder) WriteHeader(statusCode int) { if w.statusWritten != nil { @@ -65,6 +101,7 @@ func (w *StreamingResponseWriterRecorder) RequireStatus(t *testing.T, status int } func (w *StreamingResponseWriterRecorder) Flush() { + w.WriteHeader(200) fw, ok := w.ResponseWriter.(http.Flusher) if ok { fw.Flush() @@ -72,4 +109,9 @@ func (w *StreamingResponseWriterRecorder) Flush() { w.flushed = true } +func (w *StreamingResponseWriterRecorder) SetWriteDeadline(d time.Time) error { + w.writeDeadline = d + return nil +} + var _ http.ResponseWriter = &StreamingResponseWriterRecorder{} diff --git a/beacon-chain/rpc/eth/events/log.go b/beacon-chain/rpc/eth/events/log.go new file mode 100644 index 000000000000..6d218a4f034a --- /dev/null +++ b/beacon-chain/rpc/eth/events/log.go @@ -0,0 +1,6 @@ +package events + +import "github.com/sirupsen/logrus" + +var logger = logrus.StandardLogger() +var log = logger.WithField("prefix", "events") diff --git a/beacon-chain/rpc/eth/events/server.go b/beacon-chain/rpc/eth/events/server.go index 26e83454e5c4..6b4e4b787f07 100644 --- a/beacon-chain/rpc/eth/events/server.go +++ b/beacon-chain/rpc/eth/events/server.go @@ -22,4 +22,5 @@ type Server struct { TrackedValidatorsCache *cache.TrackedValidatorsCache KeepAliveInterval time.Duration EventFeedDepth int + EventWriteTimeout time.Duration } diff --git a/testing/util/BUILD.bazel b/testing/util/BUILD.bazel index fb194bfc1093..16154398cf7e 100644 --- a/testing/util/BUILD.bazel +++ b/testing/util/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "electra_state.go", "helpers.go", "lightclient.go", + "logging.go", "merge.go", "state.go", "sync_aggregate.go", @@ -69,6 +70,7 @@ go_library( "@com_github_pkg_errors//:go_default_library", "@com_github_prysmaticlabs_go_bitfield//:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", + "@com_github_sirupsen_logrus//hooks/test:go_default_library", "@io_bazel_rules_go//go/tools/bazel:go_default_library", ], ) @@ -83,6 +85,7 @@ go_test( "deneb_test.go", "deposits_test.go", "helpers_test.go", + "logging_test.go", "state_test.go", ], embed = [":go_default_library"], @@ -106,6 +109,8 @@ go_test( "//testing/assert:go_default_library", "//testing/require:go_default_library", "//time/slots:go_default_library", + "@com_github_pkg_errors//:go_default_library", + "@com_github_sirupsen_logrus//:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", ], ) diff --git a/testing/util/logging.go b/testing/util/logging.go new file mode 100644 index 000000000000..a4da28453d0b --- /dev/null +++ b/testing/util/logging.go @@ -0,0 +1,90 @@ +package util + +import ( + "github.com/sirupsen/logrus" + "github.com/sirupsen/logrus/hooks/test" +) + +// ComparableHook is an interface that allows hooks to be uniquely identified +// so that tests can safely unregister them as part of cleanup. +type ComparableHook interface { + logrus.Hook + Equal(other logrus.Hook) bool +} + +// UnregisterHook removes a hook that implements the HookIdentifier interface +// from all levels of the given logger. 
+func UnregisterHook(logger *logrus.Logger, unregister ComparableHook) { + found := false + replace := make(logrus.LevelHooks) + for lvl, hooks := range logger.Hooks { + for _, h := range hooks { + if unregister.Equal(h) { + found = true + continue + } + replace[lvl] = append(replace[lvl], h) + } + } + if !found { + return + } + logger.ReplaceHooks(replace) +} + +var highestLevel logrus.Level + +// RegisterHookWithUndo adds a hook to the logger and +// returns a function that can be called to remove it. This is intended to be used in tests +// to ensure that test hooks are removed after the test is complete. +func RegisterHookWithUndo(logger *logrus.Logger, hook ComparableHook) func() { + level := logger.Level + logger.AddHook(hook) + // set level to highest possible to ensure that hook is called for all log levels + logger.SetLevel(highestLevel) + return func() { + UnregisterHook(logger, hook) + logger.SetLevel(level) + } +} + +// NewChannelEntryWriter creates a new ChannelEntryWriter. +// The channel argument will be sent all log entries. +// Note that if this is an unbuffered channel, it is the responsibility +// of the code using it to make sure that it is drained appropriately, +// or calls to the logger can block. +func NewChannelEntryWriter(c chan *logrus.Entry) *ChannelEntryWriter { + return &ChannelEntryWriter{c: c} +} + +// ChannelEntryWriter embeds/wraps the test.Hook struct +// and adds a channel to receive log entries every time the +// Fire method of the Hook interface is called. +type ChannelEntryWriter struct { + test.Hook + c chan *logrus.Entry +} + +// Fire delegates to the embedded test.Hook Fire method after +// sending the log entry to the channel. +func (c *ChannelEntryWriter) Fire(e *logrus.Entry) error { + if c.c != nil { + c.c <- e + } + return c.Hook.Fire(e) +} + +func (c *ChannelEntryWriter) Equal(other logrus.Hook) bool { + return c == other +} + +var _ logrus.Hook = &ChannelEntryWriter{} +var _ ComparableHook = &ChannelEntryWriter{} + +func init() { + for _, level := range logrus.AllLevels { + if level > highestLevel { + highestLevel = level + } + } +} diff --git a/testing/util/logging_test.go b/testing/util/logging_test.go new file mode 100644 index 000000000000..d596abbfc466 --- /dev/null +++ b/testing/util/logging_test.go @@ -0,0 +1,78 @@ +package util + +import ( + "testing" + "time" + + "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/v5/testing/require" + "github.com/sirupsen/logrus" +) + +func TestUnregister(t *testing.T) { + logger := logrus.New() + logger.SetLevel(logrus.PanicLevel) // set to lowest log level to test level override in + assertNoHooks(t, logger) + c := make(chan *logrus.Entry, 1) + tl := NewChannelEntryWriter(c) + undo := RegisterHookWithUndo(logger, tl) + assertRegistered(t, logger, tl) + logger.Trace("test") + select { + case <-c: + default: + t.Fatalf("Expected log entry, got none") + } + undo() + assertNoHooks(t, logger) + require.Equal(t, logrus.PanicLevel, logger.Level) +} + +var logTestErr = errors.New("test") + +func TestChannelEntryWriter(t *testing.T) { + logger := logrus.New() + c := make(chan *logrus.Entry) + tl := NewChannelEntryWriter(c) + logger.AddHook(tl) + msg := "test" + go func() { + logger.WithError(logTestErr).Info(msg) + }() + select { + case e := <-c: + gotErr := e.Data[logrus.ErrorKey] + if gotErr == nil { + t.Fatalf("Expected error in log entry, got nil") + } + ge, ok := gotErr.(error) + require.Equal(t, true, ok, "Expected error in log entry to be of type error, got %T", gotErr) + require.ErrorIs(t, 
ge, logTestErr) + require.Equal(t, msg, e.Message) + require.Equal(t, logrus.InfoLevel, e.Level) + case <-time.After(10 * time.Millisecond): + t.Fatalf("Timed out waiting for log entry") + } +} + +func assertNoHooks(t *testing.T, logger *logrus.Logger) { + for lvl, hooks := range logger.Hooks { + for _, hook := range hooks { + t.Fatalf("Expected no hooks, got %v at level %s", hook, lvl.String()) + } + } +} + +func assertRegistered(t *testing.T, logger *logrus.Logger, hook ComparableHook) { + for _, lvl := range hook.Levels() { + registered := logger.Hooks[lvl] + found := false + for _, h := range registered { + if hook.Equal(h) { + found = true + break + } + } + require.Equal(t, true, found, "Expected hook %v to be registered at level %s, but it was not", hook, lvl.String()) + } +}
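
Reviewer note: for anyone who has not used http.ResponseController (Go 1.20+), below is a minimal, self-contained sketch of the write-deadline pattern that the new streamingResponseWriterController in this patch builds on. The handler name sseSketchHandler, the 12-second timeout, and the localhost address are illustrative only and are not part of the PR; the PR derives its default timeout from params.BeaconConfig().SecondsPerSlot and clears the deadline after every successful Write and Flush in the same way.

package main

import (
	"fmt"
	"net/http"
	"time"
)

// sseSketchHandler demonstrates the pattern: arm a write deadline before each
// Write/Flush so a wedged client cannot block the handler goroutine forever,
// then clear the deadline once the call returns so it cannot fire while the
// handler is idle waiting for the next event.
func sseSketchHandler(w http.ResponseWriter, r *http.Request) {
	rc := http.NewResponseController(w)
	timeout := 12 * time.Second // illustrative value

	if err := rc.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
		return // the underlying ResponseWriter does not support deadlines
	}
	if _, err := fmt.Fprint(w, ": keepalive\n\n"); err != nil {
		return // the write timed out or the client went away
	}
	if err := rc.Flush(); err != nil {
		return
	}
	// A zero time clears the deadline until the next write is attempted.
	_ = rc.SetWriteDeadline(time.Time{})
}

func main() {
	http.HandleFunc("/events", sseSketchHandler)
	_ = http.ListenAndServe("localhost:8080", nil)
}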