diff --git a/docs/content/recipes/cors.md b/docs/content/recipes/cors.md index 0b0ac114bfa..129768410d3 100644 --- a/docs/content/recipes/cors.md +++ b/docs/content/recipes/cors.md @@ -40,7 +40,7 @@ func main() { srv := handler.NewDefaultServer(starwars.NewExecutableSchema(starwars.NewResolver())) srv.AddTransport(&transport.Websocket{ - Upgrader: websocket.Upgrader{ + Upgrader: &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { // Check against your desired domains here return r.Host == "example.org" diff --git a/example/chat/server/server.go b/example/chat/server/server.go index 1592f6b0ab5..e07595319ca 100644 --- a/example/chat/server/server.go +++ b/example/chat/server/server.go @@ -34,7 +34,7 @@ func main() { srv.AddTransport(transport.POST{}) srv.AddTransport(transport.Websocket{ KeepAlivePingInterval: 10 * time.Second, - Upgrader: websocket.Upgrader{ + Upgrader: &websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 3089a877999..e27f86ca785 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -31,7 +31,7 @@ const ( type ( Websocket struct { - Upgrader websocket.Upgrader + Upgrader WebsocketUpgrader InitFunc WebsocketInitFunc KeepAlivePingInterval time.Duration } @@ -51,6 +51,9 @@ type ( ID string `json:"id,omitempty"` Type string `json:"type"` } + WebsocketUpgrader interface { + Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) + } WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) ) diff --git a/handler/handler.go b/handler/handler.go index 892df53986a..8eb2680a655 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -11,7 +11,6 @@ import ( "github.com/99designs/gqlgen/graphql/handler/lru" "github.com/99designs/gqlgen/graphql/handler/transport" "github.com/99designs/gqlgen/graphql/playground" - "github.com/gorilla/websocket" ) // Deprecated: switch to graphql/handler.New @@ -74,7 +73,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc // Deprecated: switch to graphql/handler.New type Config struct { cacheSize int - upgrader websocket.Upgrader + upgrader transport.WebsocketUpgrader websocketInitFunc transport.WebsocketInitFunc connectionKeepAlivePingInterval time.Duration recover graphql.RecoverFunc @@ -93,7 +92,7 @@ type Config struct { type Option func(cfg *Config) // Deprecated: switch to graphql/handler.New -func WebsocketUpgrader(upgrader websocket.Upgrader) Option { +func WebsocketUpgrader(upgrader transport.WebsocketUpgrader) Option { return func(cfg *Config) { cfg.upgrader = upgrader }