Skip to content

Commit

Permalink
Use http.NewResponseController()
Browse files Browse the repository at this point in the history
  • Loading branch information
pbacsko committed Nov 29, 2023
1 parent 498e2a8 commit 1cd63b8
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 7 deletions.
5 changes: 5 additions & 0 deletions pkg/webservice/handler_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package webservice

import (
"net/http"
"time"
)

// InternalMetricHistory needs resetting between tests
Expand Down Expand Up @@ -49,3 +50,7 @@ func (trw *MockResponseWriter) Write(bytes []byte) (int, error) {
func (trw *MockResponseWriter) WriteHeader(statusCode int) {
trw.statusCode = statusCode
}

func (trw *MockResponseWriter) SetWriteDeadline(deadline time.Time) error {
return nil
}
15 changes: 15 additions & 0 deletions pkg/webservice/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"sort"
"strconv"
"strings"
"time"

"github.com/julienschmidt/httprouter"
"github.com/prometheus/client_golang/prometheus/promhttp"
Expand Down Expand Up @@ -943,6 +944,12 @@ func getStream(w http.ResponseWriter, r *http.Request) {
}
}

rc := http.NewResponseController(w)
err := rc.SetWriteDeadline(time.Time{})
if err != nil {
buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError)
return
}
enc := json.NewEncoder(w)
stream := eventSystem.CreateEventStream(r.Host, count)

Expand All @@ -957,6 +964,14 @@ func getStream(w http.ResponseWriter, r *http.Request) {
eventSystem.RemoveStream(stream)
return
case e, ok := <-stream.Events:
err := rc.SetWriteDeadline(time.Now().Add(5 * time.Second))
if err != nil {
// should not fail at this point
buildJSONErrorResponse(w, fmt.Sprintf("Cannot set write deadline: %v", err), http.StatusInternalServerError)
eventSystem.RemoveStream(stream)
return
}

if !ok {
// the channel was closed by the event system itself
msg := "Event stream was closed by the producer"
Expand Down
70 changes: 63 additions & 7 deletions pkg/webservice/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package webservice
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -1477,7 +1478,7 @@ func TestGetStream(t *testing.T) {
defer cancel()
req = req.Clone(cancelCtx)

resp := httptest.NewRecorder() // MockResponseWriter does not implement http.Flusher
resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher

go func() {
time.Sleep(200 * time.Millisecond)
Expand Down Expand Up @@ -1516,7 +1517,7 @@ func TestGetStream_StreamClosedByProducer(t *testing.T) {
var req *http.Request
req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
assert.NilError(t, err)
resp := httptest.NewRecorder() // MockResponseWriter does not implement http.Flusher
resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher

go func() {
time.Sleep(200 * time.Millisecond)
Expand All @@ -1540,10 +1541,6 @@ func TestGetStream_StreamClosedByProducer(t *testing.T) {
}

func TestGetStream_NotFlusherImpl(t *testing.T) {
events.Init()
ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
ev.StartServiceWithPublisher(false)

var req *http.Request
req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
assert.NilError(t, err)
Expand All @@ -1566,7 +1563,7 @@ func TestGetStream_Count(t *testing.T) {
cancelCtx, cancel := context.WithCancel(context.Background())
defer cancel()
req = req.Clone(cancelCtx)
resp := httptest.NewRecorder() // MockResponseWriter does not implement http.Flusher
resp := NewResponseRecorderWithDeadline() // MockResponseWriter does not implement http.Flusher

// add some existing events
ev.AddEvent(&si.EventRecord{TimestampNano: 0})
Expand Down Expand Up @@ -1647,6 +1644,47 @@ func TestGetStream_TrackingDisabled(t *testing.T) {
assertYunikornError(t, line, "Event tracking is disabled")
}

func TestGetStream_NoWriteDeadline(t *testing.T) {
events.Init()
ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
ev.StartServiceWithPublisher(false)

var req *http.Request
req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
assert.NilError(t, err)
resp := httptest.NewRecorder() // does not have SetWriteDeadline()

getStream(resp, req)

output := make([]byte, 256)
n, err := resp.Body.Read(output)
assert.NilError(t, err)
line := string(output[:n])
assertYunikornError(t, line, "Cannot set write deadline: feature not supported")
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}

func TestGetStream_SetWriteDeadlineFails(t *testing.T) {
events.Init()
ev := events.GetEventSystem().(*events.EventSystemImpl) //nolint:errcheck
ev.StartServiceWithPublisher(false)

var req *http.Request
req, err := http.NewRequest("GET", "/ws/v1/events/stream", strings.NewReader(""))
assert.NilError(t, err)
resp := NewResponseRecorderWithDeadline()
resp.setWriteFails = true

getStream(resp, req)

output := make([]byte, 256)
n, err := resp.Body.Read(output)
assert.NilError(t, err)
line := string(output[:n])
assertYunikornError(t, line, "Cannot set write deadline: SetWriteDeadline failed")
assert.Equal(t, http.StatusInternalServerError, resp.Code)
}

func assertEvent(t *testing.T, output string, tsNano int64, objectID string) {
t.Helper()
var evt si.EventRecord
Expand Down Expand Up @@ -1834,3 +1872,21 @@ func verifyStateDumpJSON(t *testing.T, aggregated *AggregatedStateInfo) {
assert.Check(t, len(aggregated.Config.SchedulerConfig.Partitions) > 0)
assert.Check(t, len(aggregated.Config.Extra) > 0)
}

type ResponseRecorderWithDeadline struct {
*httptest.ResponseRecorder
setWriteFails bool
}

func (rrd *ResponseRecorderWithDeadline) SetWriteDeadline(_ time.Time) error {
if rrd.setWriteFails {
return errors.New("SetWriteDeadline failed")
}
return nil
}

func NewResponseRecorderWithDeadline() *ResponseRecorderWithDeadline {
return &ResponseRecorderWithDeadline{
ResponseRecorder: httptest.NewRecorder(),
}
}

0 comments on commit 1cd63b8

Please sign in to comment.