Skip to content

Commit

Permalink
add option to configure used serve context
Browse files Browse the repository at this point in the history
  • Loading branch information
nicowolf91 committed Aug 5, 2024
1 parent 8dae0c4 commit 3e40144
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
15 changes: 14 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gosse

import (
"context"
"net/http"
"time"
)
Expand All @@ -14,6 +15,7 @@ type Server struct {
listenersAdditionalHeader http.Header
listenersKeepAliveInterval time.Duration
listenersKeepAliveMessage Messager
serveContext func(r *http.Request) context.Context
}

func NewServer(optionSetters ...ServerOptionSetter) *Server {
Expand All @@ -26,6 +28,7 @@ func NewServer(optionSetters ...ServerOptionSetter) *Server {
listenersAdditionalHeader: http.Header{},
listenersKeepAliveInterval: 10 * time.Second,
listenersKeepAliveMessage: DefaultKeepAliveMessage,
serveContext: defaultServeContext,
}

for _, setter := range optionSetters {
Expand Down Expand Up @@ -68,7 +71,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
stream := s.messageBroker.Subscribe(channelID)

serveRequest(
r.Context(),
s.serveContext(r),
sseRequest,
s.messageConverter,
stream,
Expand Down Expand Up @@ -151,3 +154,13 @@ func WithListenersKeepAliveMessage(msg Messager) ServerOptionSetter {
}
}
}

func WithServeContext(f func(r *http.Request) context.Context) ServerOptionSetter {
return func(server *Server) {
server.serveContext = f
}
}

var defaultServeContext = func(r *http.Request) context.Context {
return r.Context()
}
23 changes: 20 additions & 3 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"runtime"
"sync"
"testing"
"time"
Expand All @@ -24,10 +26,10 @@ func TestServerConstructorDefaults(t *testing.T) {
WithListenersAdditionalHeader(http.Header{}),
WithListenersKeepAliveInterval(10*time.Second),
WithListenersKeepAliveMessage(DefaultKeepAliveMessage),
WithServeContext(defaultServeContext),
)

server := NewServer()
assert.Equal(t, server, serverWithDefaultValues)
assert.Equal(t, DefaultChannelIDExtractor, serverWithDefaultValues.channelIDExtractor)
assert.Equal(t, DefaultMessageToBytesConverter, serverWithDefaultValues.messageConverter)
assert.Equal(t, NewChannelBroker[string, Messager](), serverWithDefaultValues.messageBroker)
Expand All @@ -36,6 +38,9 @@ func TestServerConstructorDefaults(t *testing.T) {
assert.Equal(t, http.Header{}, serverWithDefaultValues.listenersAdditionalHeader)
assert.Equal(t, 10*time.Second, serverWithDefaultValues.listenersKeepAliveInterval)
assert.Equal(t, DefaultKeepAliveMessage, serverWithDefaultValues.listenersKeepAliveMessage)
funcName1 := runtime.FuncForPC(reflect.ValueOf(defaultServeContext).Pointer()).Name()
funcName2 := runtime.FuncForPC(reflect.ValueOf(server.serveContext).Pointer()).Name()
assert.Equal(t, funcName1, funcName2)
}

func TestServerConstructor(t *testing.T) {
Expand All @@ -48,6 +53,9 @@ func TestServerConstructor(t *testing.T) {
additionalHeader.Set("test", "123")
keepAliveInterval := 1337 * time.Millisecond
keepAliveMessage := NewMessage().WithData([]byte(": stay awake!"))
serveContextFunc := func(r *http.Request) context.Context {
return context.TODO()
}

server := NewServer(
WithChannelIDExtractor(extractor),
Expand All @@ -58,6 +66,7 @@ func TestServerConstructor(t *testing.T) {
WithListenersAdditionalHeader(additionalHeader),
WithListenersKeepAliveInterval(keepAliveInterval),
WithListenersKeepAliveMessage(keepAliveMessage),
WithServeContext(serveContextFunc),
)

assert.Equal(t, extractor, server.channelIDExtractor)
Expand All @@ -68,6 +77,9 @@ func TestServerConstructor(t *testing.T) {
assert.Equal(t, additionalHeader, server.listenersAdditionalHeader)
assert.Equal(t, keepAliveInterval, server.listenersKeepAliveInterval)
assert.Equal(t, keepAliveMessage, server.listenersKeepAliveMessage)
funcName1 := runtime.FuncForPC(reflect.ValueOf(serveContextFunc).Pointer()).Name()
funcName2 := runtime.FuncForPC(reflect.ValueOf(server.serveContext).Pointer()).Name()
assert.Equal(t, funcName1, funcName2)
}

func TestServer_PublishBroadcast(t *testing.T) {
Expand Down Expand Up @@ -136,15 +148,20 @@ func TestServer_ServeHTTP(t *testing.T) {

request := httptest.NewRequest("GET", requestUrl.String(), nil)
request.Header.Set("Last-Event-ID", "1337")
ctx, cancel := context.WithCancel(request.Context())
ctx, cancel := context.WithCancel(context.Background())
cancel()

replayMessages := []Messager{
NewMessage().WithData([]byte("replay 1")),
NewMessage().WithData([]byte("replay 2")),
}
replayer := &messageReplayerMock{ret: replayMessages}
server := NewServer(WithMessageReplayer(replayer))
server := NewServer(
WithMessageReplayer(replayer),
WithServeContext(func(r *http.Request) context.Context {
return ctx
}),
)

recorder := httptest.NewRecorder()
server.ServeHTTP(recorder, request.WithContext(ctx))
Expand Down

0 comments on commit 3e40144

Please sign in to comment.