Skip to content

Commit

Permalink
rpc: rename wsMessageSizeLimit (it's now a default), minor test-changes
Browse files Browse the repository at this point in the history
  • Loading branch information
holiman committed Sep 5, 2023
1 parent e28e681 commit d6a104c
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 26 deletions.
2 changes: 1 addition & 1 deletion rpc/client_opt.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type clientConfig struct {

// WebSocket options
wsDialer *websocket.Dialer
wsMessageSizeLimit *int64
wsMessageSizeLimit *int64 // wsMessageSizeLimit nil = default, 0 = no limit

// RPC handler options
idgen func() ID
Expand Down
14 changes: 6 additions & 8 deletions rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const (
wsPingInterval = 30 * time.Second
wsPingWriteTimeout = 5 * time.Second
wsPongTimeout = 30 * time.Second
wsMessageSizeLimit = 32 * 1024 * 1024
wsDefaultReadLimit = 32 * 1024 * 1024
)

var wsBufferPool = new(sync.Pool)
Expand All @@ -60,7 +60,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn, r.Host, r.Header, wsMessageSizeLimit)
codec := newWebsocketCodec(conn, r.Host, r.Header, wsDefaultReadLimit)
s.ServeCodec(codec, 0)
})
}
Expand Down Expand Up @@ -251,11 +251,9 @@ func newClientTransportWS(endpoint string, cfg *clientConfig) (reconnectFunc, er
}
return nil, hErr
}
var messageSizeLimit int64
if cfg.wsMessageSizeLimit != nil {
messageSizeLimit := int64(wsDefaultReadLimit)
if cfg.wsMessageSizeLimit != nil && *cfg.wsMessageSizeLimit >= 0 {
messageSizeLimit = *cfg.wsMessageSizeLimit
} else {
messageSizeLimit = wsMessageSizeLimit
}
return newWebsocketCodec(conn, dialURL, header, messageSizeLimit), nil
}
Expand Down Expand Up @@ -289,8 +287,8 @@ type websocketCodec struct {
pongReceived chan struct{}
}

func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, messageSizeLimit int64) ServerCodec {
conn.SetReadLimit(messageSizeLimit)
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header, readLimit int64) ServerCodec {
conn.SetReadLimit(readLimit)
encode := func(v interface{}, isErrorResponse bool) error {
return conn.WriteJSON(v)
}
Expand Down
48 changes: 31 additions & 17 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,38 +125,52 @@ func TestWebsocketLargeRead(t *testing.T) {
defer srv.Stop()
defer httpsrv.Close()

testLimit := func(limit int64) {
testLimit := func(limit *int64) {
opts := []ClientOption{}
if limit >= 0 {
opts = append(opts, WithWebsocketMessageSizeLimit(limit))
} else {
limit = wsMessageSizeLimit
expLimit := int64(wsDefaultReadLimit)
if limit != nil && *limit >= 0 {
opts = append(opts, WithWebsocketMessageSizeLimit(*limit))
if *limit > 0 {
expLimit = *limit // 0 means infinite
}
}
client, err := DialOptions(context.Background(), wsURL, opts...)
if err != nil {
t.Fatalf("can't dial: %v", err)
}
defer client.Close()

// Remove some bytes for json encoding overhead.
underLimit := int(limit - 128)
underLimit := int(expLimit - 128)
overLimit := expLimit + 1
if expLimit == wsDefaultReadLimit {
// No point trying the full 32MB in tests. Just sanity-check that
// it's not obviously limited.
underLimit = 1024
overLimit = -1
}
var res string
err = client.Call(&res, "test_repeat", "A", underLimit)
if err != nil {
t.Fatalf("unexpected error with limit %d: %v", limit, err)
// Check under limit
if err = client.Call(&res, "test_repeat", "A", underLimit); err != nil {
t.Fatalf("unexpected error with limit %d: %v", expLimit, err)
}
if len(res) != underLimit || strings.Count(res, "A") != underLimit {
t.Fatal("incorrect data")
}

err = client.Call(&res, "test_repeat", "A", limit+1)
if err == nil || err != websocket.ErrReadLimit {
t.Fatalf("wrong error with limit %d: %v expecting %v", limit, err, websocket.ErrReadLimit)
// Check over limit
if overLimit > 0 {
err = client.Call(&res, "test_repeat", "A", expLimit+1)
if err == nil || err != websocket.ErrReadLimit {
t.Fatalf("wrong error with limit %d: %v expecting %v", expLimit, err, websocket.ErrReadLimit)
}
}
}
ptr := func(v int64) *int64 { return &v }

testLimit(-1)
testLimit(wsMessageSizeLimit * 2)
testLimit(ptr(-1)) // Should be ignored (use default)
testLimit(ptr(0)) // Should be ignored (use default)
testLimit(nil) // Should be ignored (use default)
testLimit(ptr(200))
testLimit(ptr(wsDefaultReadLimit * 2))
}

func TestWebsocketPeerInfo(t *testing.T) {
Expand Down Expand Up @@ -252,7 +266,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
defer srv.Stop()
defer httpsrv.Close()

respLength := wsMessageSizeLimit - 50
respLength := wsDefaultReadLimit - 50
srv.RegisterName("test", largeRespService{respLength})

c, err := DialWebsocket(context.Background(), wsURL, "")
Expand Down

0 comments on commit d6a104c

Please sign in to comment.