From 1cd63b804e815d8aad6c193d38bdadbc807bbbdd Mon Sep 17 00:00:00 2001 From: Peter Bacsko Date: Wed, 29 Nov 2023 12:42:00 +0100 Subject: [PATCH] Use http.NewResponseController() --- pkg/webservice/handler_mock_test.go | 5 +++ pkg/webservice/handlers.go | 15 +++++++ pkg/webservice/handlers_test.go | 70 ++++++++++++++++++++++++++--- 3 files changed, 83 insertions(+), 7 deletions(-) diff --git a/pkg/webservice/handler_mock_test.go b/pkg/webservice/handler_mock_test.go index 9759afae9..439efe4fc 100644 --- a/pkg/webservice/handler_mock_test.go +++ b/pkg/webservice/handler_mock_test.go @@ -19,6 +19,7 @@ package webservice import ( "net/http" + "time" ) // InternalMetricHistory needs resetting between tests @@ -49,3 +50,7 @@ func (trw *MockResponseWriter) Write(bytes []byte) (int, error) { func (trw *MockResponseWriter) WriteHeader(statusCode int) { trw.statusCode = statusCode } + +func (trw *MockResponseWriter) SetWriteDeadline(deadline time.Time) error { + return nil +} diff --git a/pkg/webservice/handlers.go b/pkg/webservice/handlers.go index b0755d392..d1da10fb2 100644 --- a/pkg/webservice/handlers.go +++ b/pkg/webservice/handlers.go @@ -28,6 +28,7 @@ import ( "sort" "strconv" "strings" + "time" "github.com/julienschmidt/httprouter" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -943,6 +944,12 @@ func getStream(w http.ResponseWriter, r *http.Request) { } } + rc := http.NewResponseController(w) + err := rc.SetWriteDeadline(time.Time{}) + if err != nil { + buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError) + return + } enc := json.NewEncoder(w) stream := eventSystem.CreateEventStream(r.Host, count) @@ -957,6 +964,14 @@ func getStream(w http.ResponseWriter, r *http.Request) { eventSystem.RemoveStream(stream) return case e, ok := <-stream.Events: + err := rc.SetWriteDeadline(time.Now().Add(5 * time.Second)) + if err != nil { + // should not fail at this point + buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError) + eventSystem.RemoveStream(stream) + return + } + if !ok { // the channel was closed by the event system itself msg := "Event stream was closed by the producer" diff --git a/pkg/webservice/handlers_test.go b/pkg/webservice/handlers_test.go index 9488c0d21..5d683c317 100644 --- a/pkg/webservice/handlers_test.go +++ b/pkg/webservice/handlers_test.go @@ -21,6 +21,7 @@ package webservice import ( "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -1477,7 +1478,7 @@ func TestGetStream(t *testing.T) { defer cancel() req = req.Clone(cancelCtx) - resp := httptest.NewRecorder() // MockResponseWriter does not implement http.Flusher + resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher go func() { time.Sleep(200 * time.Millisecond) @@ -1516,7 +1517,7 @@ func TestGetStream_StreamClosedByProducer(t *testing.T) { var req *http.Request req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader("")) assert.NilError(t, err) - resp := httptest.NewRecorder() // MockResponseWriter does not implement http.Flusher + resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher go func() { time.Sleep(200 * time.Millisecond) @@ -1540,10 +1541,6 @@ func TestGetStream_StreamClosedByProducer(t *testing.T) { } func TestGetStream_NotFlusherImpl(t *testing.T) { - events.Init() - ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck - ev.StartServiceWithPublisher(false) - var req *http.Request req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader("")) assert.NilError(t, err) @@ -1566,7 +1563,7 @@ func TestGetStream_Count(t *testing.T) { cancelCtx, cancel := context.WithCancel(context.Background()) defer cancel() req = req.Clone(cancelCtx) - resp := httptest.NewRecorder() // MockResponseWriter does not implement http.Flusher + resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher // add some existing events ev.AddEvent(&si.EventRecord{TimestampNano: 0}) @@ -1647,6 +1644,47 @@ func TestGetStream_TrackingDisabled(t *testing.T) { assertYunikornError(t, line, "Event tracking is disabled") } +func TestGetStream_NoWriteDeadline(t *testing.T) { + events.Init() + ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck + ev.StartServiceWithPublisher(false) + + var req *http.Request + req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader("")) + assert.NilError(t, err) + resp := httptest.NewRecorder() // does not have SetWriteDeadline() + + getStream(resp, req) + + output := make([]byte, 256) + n, err := resp.Body.Read(output) + assert.NilError(t, err) + line := string(output[:n]) + assertYunikornError(t, line, "Cannot set write deadline: feature not supported") + assert.Equal(t, http.StatusInternalServerError, resp.Code) +} + +func TestGetStream_SetWriteDeadlineFails(t *testing.T) { + events.Init() + ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck + ev.StartServiceWithPublisher(false) + + var req *http.Request + req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader("")) + assert.NilError(t, err) + resp := NewResponseRecorderWithDeadline() + resp.setWriteFails = true + + getStream(resp, req) + + output := make([]byte, 256) + n, err := resp.Body.Read(output) + assert.NilError(t, err) + line := string(output[:n]) + assertYunikornError(t, line, "Cannot set write deadline: SetWriteDeadline failed") + assert.Equal(t, http.StatusInternalServerError, resp.Code) +} + func assertEvent(t *testing.T, output string, tsNano int64, objectID string) { t.Helper() var evt si.EventRecord @@ -1834,3 +1872,21 @@ func verifyStateDumpJSON(t *testing.T, aggregated *AggregatedStateInfo) { assert.Check(t, len(aggregated.Config.SchedulerConfig.Partitions) > 0) assert.Check(t, len(aggregated.Config.Extra) > 0) } + +type ResponseRecorderWithDeadline struct { + *httptest.ResponseRecorder + setWriteFails bool +} + +func (rrd *ResponseRecorderWithDeadline) SetWriteDeadline(_ time.Time) error { + if rrd.setWriteFails { + return errors.New("SetWriteDeadline failed") + } + return nil +} + +func NewResponseRecorderWithDeadline() *ResponseRecorderWithDeadline { + return &ResponseRecorderWithDeadline{ + ResponseRecorder: httptest.NewRecorder(), + } +}