Skip to content

Add Grace to gracefully close WebSocket connections #200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn,
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")

g := graceFromRequest(r)
if g != nil && g.isShuttingdown() {
err := errors.New("server shutting down")
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return nil, err
}

if opts == nil {
opts = &AcceptOptions{}
}
Expand Down Expand Up @@ -134,7 +141,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))

return newConn(connConfig{
c := newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
Expand All @@ -143,7 +150,16 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con

br: brw.Reader,
bw: brw.Writer,
}), nil
})

if g != nil {
err = g.addConn(c)
if err != nil {
return nil, err
}
}

return c, nil
}

func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
Expand Down
6 changes: 6 additions & 0 deletions chat-example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ assets, the `/subscribe` WebSocket endpoint and the HTTP POST `/publish` endpoin

The code is well commented. I would recommend starting in `main.go` and then `chat.go` followed by
`index.html` and then `index.js`.

There are two automated tests for the server included in `chat_test.go`. The first is a simple one
client echo test. It publishes a single message and ensures it's received.

The second is a complex concurrency test where 10 clients send 128 unique messages
of max 128 bytes concurrently. The test ensures all messages are seen by every client.
106 changes: 80 additions & 26 deletions chat-example/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,75 @@ package main
import (
"context"
"errors"
"io"
"io/ioutil"
"log"
"net/http"
"sync"
"time"

"golang.org/x/time/rate"

"nhooyr.io/websocket"
)

// chatServer enables broadcasting to a set of subscribers.
type chatServer struct {
subscribersMu sync.RWMutex
subscribers map[chan<- []byte]struct{}
// subscriberMessageBuffer controls the max number
// of messages that can be queued for a subscriber
// before it is kicked.
//
// Defaults to 16.
subscriberMessageBuffer int

// publishLimiter controls the rate limit applied to the publish endpoint.
//
// Defaults to one publish every 100ms with a burst of 8.
publishLimiter *rate.Limiter

// logf controls where logs are sent.
// Defaults to log.Printf.
logf func(f string, v ...interface{})

// serveMux routes the various endpoints to the appropriate handler.
serveMux http.ServeMux

subscribersMu sync.Mutex
subscribers map[*subscriber]struct{}
}

// newChatServer constructs a chatServer with the defaults.
func newChatServer() *chatServer {
cs := &chatServer{
subscriberMessageBuffer: 16,
logf: log.Printf,
subscribers: make(map[*subscriber]struct{}),
publishLimiter: rate.NewLimiter(rate.Every(time.Millisecond*100), 8),
}
cs.serveMux.Handle("/", http.FileServer(http.Dir(".")))
cs.serveMux.HandleFunc("/subscribe", cs.subscribeHandler)
cs.serveMux.HandleFunc("/publish", cs.publishHandler)

return cs
}

// subscriber represents a subscriber.
// Messages are sent on the msgs channel and if the client
// cannot keep up with the messages, closeSlow is called.
type subscriber struct {
msgs chan []byte
closeSlow func()
}

func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
cs.serveMux.ServeHTTP(w, r)
}

// subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
log.Print(err)
cs.logf("%v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
Expand All @@ -38,7 +85,8 @@ func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
return
}
if err != nil {
log.Print(err)
cs.logf("%v", err)
return
}
}

Expand All @@ -49,19 +97,21 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return
}
body := io.LimitReader(r.Body, 8192)
body := http.MaxBytesReader(w, r.Body, 8192)
msg, err := ioutil.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
return
}

cs.publish(msg)

w.WriteHeader(http.StatusAccepted)
}

// subscribe subscribes the given WebSocket to all broadcast messages.
// It creates a msgs chan with a buffer of 16 to give some room to slower
// connections and then registers it. It then listens for all messages
// It creates a subscriber with a buffered msgs chan to give some room to slower
// connections and then registers the subscriber. It then listens for all messages
// and writes them to the WebSocket. If the context is cancelled or
// an error occurs, it returns and deletes the subscription.
//
Expand All @@ -70,13 +120,18 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
ctx = c.CloseRead(ctx)

msgs := make(chan []byte, 16)
cs.addSubscriber(msgs)
defer cs.deleteSubscriber(msgs)
s := &subscriber{
msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
},
}
cs.addSubscriber(s)
defer cs.deleteSubscriber(s)

for {
select {
case msg := <-msgs:
case msg := <-s.msgs:
err := writeTimeout(ctx, time.Second*5, c, msg)
if err != nil {
return err
Expand All @@ -91,32 +146,31 @@ func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
// It never blocks and so messages to slow subscribers
// are dropped.
func (cs *chatServer) publish(msg []byte) {
cs.subscribersMu.RLock()
defer cs.subscribersMu.RUnlock()
cs.subscribersMu.Lock()
defer cs.subscribersMu.Unlock()

cs.publishLimiter.Wait(context.Background())

for c := range cs.subscribers {
for s := range cs.subscribers {
select {
case c <- msg:
case s.msgs <- msg:
default:
go s.closeSlow()
}
}
}

// addSubscriber registers a subscriber with a channel
// on which to send messages.
func (cs *chatServer) addSubscriber(msgs chan<- []byte) {
// addSubscriber registers a subscriber.
func (cs *chatServer) addSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
if cs.subscribers == nil {
cs.subscribers = make(map[chan<- []byte]struct{})
}
cs.subscribers[msgs] = struct{}{}
cs.subscribers[s] = struct{}{}
cs.subscribersMu.Unlock()
}

// deleteSubscriber deletes the subscriber with the given msgs channel.
func (cs *chatServer) deleteSubscriber(msgs chan []byte) {
// deleteSubscriber deletes the given subscriber.
func (cs *chatServer) deleteSubscriber(s *subscriber) {
cs.subscribersMu.Lock()
delete(cs.subscribers, msgs)
delete(cs.subscribers, s)
cs.subscribersMu.Unlock()
}

Expand Down
Loading