Skip to content

Commit

Permalink
fix: improve ws subprotocol selection
Browse files Browse the repository at this point in the history
  • Loading branch information
JivusAyrus committed May 12, 2024
1 parent 22c2f9a commit 1fc0dd9
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 36 deletions.
2 changes: 0 additions & 2 deletions execution/engine/config_factory_federation.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,15 +447,13 @@ func (f *FederationEngineConfigFactory) subscriptionClient(
httpClient,
streamingClient,
f.engineCtx,
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLTWS),
)
default:
// for compatibility reasons we fall back to graphql-ws protocol
graphqlSubscriptionClient = subscriptionClientFactory.NewSubscriptionClient(
httpClient,
streamingClient,
f.engineCtx,
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLWS),
)
}

Expand Down
2 changes: 0 additions & 2 deletions execution/engine/engine_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,13 @@ func (d *graphqlDataSourceGenerator) generateSubscriptionClient(httpClient *http
httpClient,
definedOptions.streamingClient,
nil,
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLTWS),
)
default:
// for compatibility reasons we fall back to graphql-ws protocol
graphqlSubscriptionClient = definedOptions.subscriptionClientFactory.NewSubscriptionClient(
httpClient,
definedOptions.streamingClient,
nil,
graphql_datasource.WithWSSubProtocol(graphql_datasource.ProtocolGraphQLWS),
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ type SubscriptionConfiguration struct {
// which connections can be multiplexed together, but the subscription engine does not forward
// these headers by itself.
ForwardedClientHeaderRegularExpressions []*regexp.Regexp
WsSubProtocol string
}

type FetchConfiguration struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {
input = httpclient.SetInputFlag(input, httpclient.SSE_METHOD_POST)
}
}
input = httpclient.SetInputWSSubprotocol(input, []byte(p.config.subscription.WsSubProtocol))

header, err := json.Marshal(p.config.subscription.Header)
if err == nil && len(header) != 0 && !bytes.Equal(header, literal.NULL) {
Expand Down Expand Up @@ -1668,6 +1669,7 @@ type GraphQLSubscriptionOptions struct {
SSEMethodPost bool `json:"sse_method_post"`
ForwardedClientHeaderNames []string `json:"forwarded_client_header_names"`
ForwardedClientHeaderRegularExpressions []*regexp.Regexp `json:"forwarded_client_header_regular_expressions"`
WsSubProtocol string `json:"ws_sub_protocol"`
}

type GraphQLBody struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8851,7 +8851,7 @@ func TestSubscription_GTWS_SubProtocol(t *testing.T) {
newSubscriptionSource := func(ctx context.Context) SubscriptionSource {
httpClient := http.Client{}
subscriptionSource := SubscriptionSource{
client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx, WithWSSubProtocol(ProtocolGraphQLTWS)),
client: NewGraphQLSubscriptionClient(&httpClient, http.DefaultClient, ctx),
}
return subscriptionSource
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,25 @@ type subscriptionClient struct {
hashPool sync.Pool
handlers map[uint64]ConnectionHandler
handlersMu sync.Mutex
wsSubProtocol string
onWsConnectionInitCallback *OnWsConnectionInitCallback

readTimeout time.Duration
}

type InvalidWsSubprotocolError struct {
Message string
}

func (e InvalidWsSubprotocolError) Error() string {
return e.Message
}

func NewInvalidWsSubprotocolError(message string) InvalidWsSubprotocolError {
return InvalidWsSubprotocolError{
Message: message,
}
}

type Options func(options *opts)

func WithLogger(log abstractlogger.Logger) Options {
Expand All @@ -50,12 +63,6 @@ func WithReadTimeout(timeout time.Duration) Options {
}
}

func WithWSSubProtocol(protocol string) Options {
return func(options *opts) {
options.wsSubProtocol = protocol
}
}

func WithOnWsConnectionInitCallback(callback *OnWsConnectionInitCallback) Options {
return func(options *opts) {
options.onWsConnectionInitCallback = callback
Expand All @@ -65,7 +72,6 @@ func WithOnWsConnectionInitCallback(callback *OnWsConnectionInitCallback) Option
type opts struct {
readTimeout time.Duration
log abstractlogger.Logger
wsSubProtocol string
onWsConnectionInitCallback *OnWsConnectionInitCallback
}

Expand Down Expand Up @@ -106,7 +112,6 @@ func NewGraphQLSubscriptionClient(httpClient, streamingClient *http.Client, engi
return xxhash.New()
},
},
wsSubProtocol: op.wsSubProtocol,
onWsConnectionInitCallback: op.onWsConnectionInitCallback,
}
}
Expand Down Expand Up @@ -288,8 +293,8 @@ func (c *subscriptionClient) requestHash(ctx *resolve.Context, options GraphQLSu

func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, options GraphQLSubscriptionOptions) (ConnectionHandler, error) {
subProtocols := []string{ProtocolGraphQLWS, ProtocolGraphQLTWS}
if c.wsSubProtocol != "" {
subProtocols = []string{c.wsSubProtocol}
if options.WsSubProtocol != "" && options.WsSubProtocol != "auto" {
subProtocols = []string{options.WsSubProtocol}
}

conn, upgradeResponse, err := websocket.Dial(reqCtx, options.URL, &websocket.DialOptions{
Expand Down Expand Up @@ -333,21 +338,25 @@ func (c *subscriptionClient) newWSConnectionHandler(reqCtx context.Context, opti
return nil, err
}

if c.wsSubProtocol == "" {
c.wsSubProtocol = conn.Subprotocol()
wsSubProtocol := subProtocols[0]
if options.WsSubProtocol == "" || options.WsSubProtocol == "auto" {
wsSubProtocol = conn.Subprotocol()
if wsSubProtocol == "" {
wsSubProtocol = ProtocolGraphQLWS
}
}

if err := waitForAck(reqCtx, conn); err != nil {
return nil, err
}

switch c.wsSubProtocol {
switch wsSubProtocol {
case ProtocolGraphQLWS:
return newGQLWSConnectionHandler(c.engineCtx, conn, c.readTimeout, c.log), nil
case ProtocolGraphQLTWS:
return newGQLTWSConnectionHandler(c.engineCtx, conn, c.readTimeout, c.log), nil
default:
return nil, fmt.Errorf("unknown protocol %s", conn.Subprotocol())
return nil, NewInvalidWsSubprotocolError(fmt.Sprintf("provided websocket subprotocol %s is not supported. The supported subprotocols are graphql-ws and graphql-transport-ws. Please configure your subsciptions with the mentioned subprotocols", wsSubProtocol))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ func TestWebsocketSubscriptionClientDeDuplication(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
)
clientsDone := &sync.WaitGroup{}

Expand Down Expand Up @@ -215,7 +214,6 @@ func TestWebsocketSubscriptionClientImmediateClientCancel(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -270,7 +268,6 @@ func TestWebsocketSubscriptionClientWithServerDisconnect(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ func TestWebsocketSubscriptionClient_GQLTWS(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLTWS),
).(*subscriptionClient)

updater := &testSubscriptionUpdater{}
Expand Down Expand Up @@ -142,7 +141,6 @@ func TestWebsocketSubscriptionClientPing_GQLTWS(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLTWS),
).(*subscriptionClient)

updater := &testSubscriptionUpdater{}
Expand Down Expand Up @@ -210,7 +208,6 @@ func TestWebsocketSubscriptionClientError_GQLTWS(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLTWS),
)

updater := &testSubscriptionUpdater{}
Expand Down Expand Up @@ -298,7 +295,6 @@ func TestWebSocketSubscriptionClientInitIncludePing_GQLTWS(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLTWS),
).(*subscriptionClient)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -373,7 +369,6 @@ func TestWebsocketSubscriptionClient_GQLTWS_Upstream_Dies(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Second),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLTWS),
).(*subscriptionClient)

updater := &testSubscriptionUpdater{}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ func TestWebSocketSubscriptionClientInitIncludeKA_GQLWS(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -144,7 +143,6 @@ func TestWebsocketSubscriptionClient_GQLWS(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -208,7 +206,6 @@ func TestWebsocketSubscriptionClientErrorArray(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -264,7 +261,6 @@ func TestWebsocketSubscriptionClientErrorObject(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(clientCtx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -329,7 +325,6 @@ func TestWebsocketSubscriptionClient_GQLWS_Upstream_Dies(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Second),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)
updater := &testSubscriptionUpdater{}
err := client.Subscribe(resolve.NewContext(ctx), GraphQLSubscriptionOptions{
Expand Down Expand Up @@ -381,7 +376,6 @@ func TestWebsocketConnectionReuse(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)

updater := &testSubscriptionUpdater{}
Expand Down Expand Up @@ -432,7 +426,6 @@ func TestWebsocketConnectionReuse(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)

updater := &testSubscriptionUpdater{}
Expand Down Expand Up @@ -471,7 +464,6 @@ func TestWebsocketConnectionReuse(t *testing.T) {
client := NewGraphQLSubscriptionClient(http.DefaultClient, http.DefaultClient, serverCtx,
WithReadTimeout(time.Millisecond),
WithLogger(logger()),
WithWSSubProtocol(ProtocolGraphQLWS),
).(*subscriptionClient)

updater := &testSubscriptionUpdater{}
Expand Down
9 changes: 9 additions & 0 deletions v2/pkg/engine/datasource/httpclient/httpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ const (
FORWARDED_CLIENT_HEADER_NAMES = "forwarded_client_header_names"
FORWARDED_CLIENT_HEADER_REGULAR_EXPRESSIONS = "forwarded_client_header_regular_expressions"
TRACE = "__trace__"
WsSubProtocol = "ws_sub_protocol"
)

var (
Expand Down Expand Up @@ -118,6 +119,14 @@ func SetInputFlag(input []byte, flagName string) []byte {
return out
}

func SetInputWSSubprotocol(input, wsSubProtocol []byte) []byte {
if len(wsSubProtocol) == 0 {
return input
}
out, _ := sjson.SetRawBytes(input, WsSubProtocol, wrapQuotesIfString(wsSubProtocol))
return out
}

func IsInputFlagSet(input []byte, flagName string) bool {
value, dataType, _, err := jsonparser.Get(input, flagName)
if err != nil {
Expand Down

0 comments on commit 1fc0dd9

Please sign in to comment.