Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

limit the number of concurrent incoming streams #66

Merged
merged 1 commit into from
Nov 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ type Config struct {
// an expectation that things will move along quickly.
ConnectionWriteTimeout time.Duration

// MaxIncomingStreams is maximum number of concurrent incoming streams
// that we accept. If the peer tries to open more streams, those will be
// reset immediately.
MaxIncomingStreams uint32

// InitialStreamWindowSize is used to control the initial
// window size that we allow for a stream.
InitialStreamWindowSize uint32
Expand Down Expand Up @@ -65,6 +70,7 @@ func DefaultConfig() *Config {
EnableKeepAlive: true,
KeepAliveInterval: 30 * time.Second,
ConnectionWriteTimeout: 10 * time.Second,
MaxIncomingStreams: 1000,
InitialStreamWindowSize: initialStreamWindow,
MaxStreamWindowSize: maxStreamWindow,
LogOutput: os.Stderr,
Expand Down
23 changes: 18 additions & 5 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
"sync/atomic"
"time"

"github.com/libp2p/go-buffer-pool"
pool "github.com/libp2p/go-buffer-pool"
)

// Session is used to wrap a reliable ordered connection and to
Expand Down Expand Up @@ -55,9 +55,10 @@ type Session struct {
// streams maps a stream id to a stream, and inflight has an entry
// for any outgoing stream that has not yet been established. Both are
// protected by streamLock.
streams map[uint32]*Stream
inflight map[uint32]struct{}
streamLock sync.Mutex
numIncomingStreams uint32
streams map[uint32]*Stream
inflight map[uint32]struct{}
streamLock sync.Mutex

// synCh acts like a semaphore. It is sized to the AcceptBacklog which
// is assumed to be symmetric between the client and server. This allows
Expand Down Expand Up @@ -735,6 +736,15 @@ func (s *Session) incomingStream(id uint32) error {
return ErrDuplicateStream
}

if s.numIncomingStreams >= s.config.MaxIncomingStreams {
// too many active streams at the same time
s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset")
delete(s.streams, id)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
}

s.numIncomingStreams++
// Register the stream
s.streams[id] = stream

Expand All @@ -744,7 +754,7 @@ func (s *Session) incomingStream(id uint32) error {
return nil
default:
// Backlog exceeded! RST the stream
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing connection reset")
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset")
delete(s.streams, id)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
Expand All @@ -764,6 +774,9 @@ func (s *Session) closeStream(id uint32) {
}
delete(s.inflight, id)
}
if s.client == (id%2 == 0) {
s.numIncomingStreams--
}
delete(s.streams, id)
s.streamLock.Unlock()
}
Expand Down
52 changes: 52 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1732,3 +1732,55 @@ func TestInitialStreamWindow(t *testing.T) {
}
}
}

func TestMaxIncomingStreams(t *testing.T) {
const maxIncomingStreams = 5
conn1, conn2 := testConn()
client, err := Client(conn1, DefaultConfig())
require.NoError(t, err)
defer client.Close()

conf := DefaultConfig()
conf.MaxIncomingStreams = maxIncomingStreams
server, err := Server(conn2, conf)
require.NoError(t, err)
defer server.Close()

strChan := make(chan *Stream, maxIncomingStreams)
go func() {
defer close(strChan)
for {
str, err := server.AcceptStream()
if err != nil {
return
}
_, err = str.Write([]byte("foobar"))
require.NoError(t, err)
strChan <- str
}
}()

for i := 0; i < maxIncomingStreams; i++ {
str, err := client.OpenStream(context.Background())
require.NoError(t, err)
_, err = str.Read(make([]byte, 6))
require.NoError(t, err)
require.NoError(t, str.CloseWrite())
}
// The server now has maxIncomingStreams incoming streams.
// It will now reset the next stream that is opened.
str, err := client.OpenStream(context.Background())
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.EqualError(t, err, "stream reset")

// Now close one of the streams.
// This should then allow the client to open a new stream.
require.NoError(t, (<-strChan).Close())
str, err = client.OpenStream(context.Background())
require.NoError(t, err)
str.SetDeadline(time.Now().Add(time.Second))
_, err = str.Read([]byte{0})
require.NoError(t, err)
}