Skip to content

Commit

Permalink
Merge pull request #69 from libp2p/rcmgr
Browse files Browse the repository at this point in the history
add a MemoryManager
  • Loading branch information
marten-seemann authored Jan 17, 2022
2 parents 3a390d9 + 8d274da commit e7b8807
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 51 deletions.
4 changes: 2 additions & 2 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ const (
const (
// initialStreamWindow is the initial stream window size.
// It's not an implementation choice, the value defined in the specification.
initialStreamWindow uint32 = 256 * 1024
maxStreamWindow uint32 = 16 * 1024 * 1024
initialStreamWindow = 256 * 1024
maxStreamWindow = 16 * 1024 * 1024
)

const (
Expand Down
8 changes: 4 additions & 4 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,25 @@ func VerifyConfig(config *Config) error {
// Server is used to initialize a new server-side connection.
// There must be at most one server-side connection. If a nil config is
// provided, the DefaultConfiguration will be used.
func Server(conn net.Conn, config *Config) (*Session, error) {
func Server(conn net.Conn, config *Config, mm MemoryManager) (*Session, error) {
if config == nil {
config = DefaultConfig()
}
if err := VerifyConfig(config); err != nil {
return nil, err
}
return newSession(config, conn, false, config.ReadBufSize), nil
return newSession(config, conn, false, config.ReadBufSize, mm), nil
}

// Client is used to initialize a new client-side connection.
// There must be at most one client-side connection.
func Client(conn net.Conn, config *Config) (*Session, error) {
func Client(conn net.Conn, config *Config, mm MemoryManager) (*Session, error) {
if config == nil {
config = DefaultConfig()
}

if err := VerifyConfig(config); err != nil {
return nil, err
}
return newSession(config, conn, true, config.ReadBufSize), nil
return newSession(config, conn, true, config.ReadBufSize, mm), nil
}
98 changes: 70 additions & 28 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,24 @@ import (
pool "github.com/libp2p/go-buffer-pool"
)

// The MemoryManager allows management of memory allocations.
// Memory is allocated:
// 1. When opening / accepting a new stream. This uses the highest priority.
// 2. When trying to increase the stream receive window. This uses a lower priority.
type MemoryManager interface {
// ReserveMemory reserves memory / buffer.
ReserveMemory(size int, prio uint8) error
// ReleaseMemory explicitly releases memory previously reserved with ReserveMemory
ReleaseMemory(size int)
}

type nullMemoryManagerImpl struct{}

func (n nullMemoryManagerImpl) ReserveMemory(size int, prio uint8) error { return nil }
func (n nullMemoryManagerImpl) ReleaseMemory(size int) {}

var nullMemoryManager MemoryManager = &nullMemoryManagerImpl{}

// Session is used to wrap a reliable ordered connection and to
// multiplex it into multiple streams.
type Session struct {
Expand Down Expand Up @@ -47,6 +65,8 @@ type Session struct {
// reader is a buffered reader
reader io.Reader

memoryManager MemoryManager

// pings is used to track inflight pings
pingLock sync.Mutex
pingID uint32
Expand Down Expand Up @@ -100,27 +120,31 @@ type Session struct {
}

// newSession is used to construct a new session
func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Session {
func newSession(config *Config, conn net.Conn, client bool, readBuf int, memoryManager MemoryManager) *Session {
var reader io.Reader = conn
if readBuf > 0 {
reader = bufio.NewReaderSize(reader, readBuf)
}
if memoryManager == nil {
memoryManager = nullMemoryManager
}
s := &Session{
config: config,
client: client,
logger: log.New(config.LogOutput, "", log.LstdFlags),
conn: conn,
reader: reader,
streams: make(map[uint32]*Stream),
inflight: make(map[uint32]struct{}),
synCh: make(chan struct{}, config.AcceptBacklog),
acceptCh: make(chan *Stream, config.AcceptBacklog),
sendCh: make(chan []byte, 64),
pongCh: make(chan uint32, config.PingBacklog),
pingCh: make(chan uint32),
recvDoneCh: make(chan struct{}),
sendDoneCh: make(chan struct{}),
shutdownCh: make(chan struct{}),
config: config,
client: client,
logger: log.New(config.LogOutput, "", log.LstdFlags),
conn: conn,
reader: reader,
streams: make(map[uint32]*Stream),
inflight: make(map[uint32]struct{}),
synCh: make(chan struct{}, config.AcceptBacklog),
acceptCh: make(chan *Stream, config.AcceptBacklog),
sendCh: make(chan []byte, 64),
pongCh: make(chan uint32, config.PingBacklog),
pingCh: make(chan uint32),
recvDoneCh: make(chan struct{}),
sendDoneCh: make(chan struct{}),
shutdownCh: make(chan struct{}),
memoryManager: memoryManager,
}
if client {
s.nextStreamID = 1
Expand Down Expand Up @@ -187,6 +211,10 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) {
return nil, s.shutdownErr
}

if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil {
return nil, err
}

GET_ID:
// Get an ID, and check for stream exhaustion
id := atomic.LoadUint32(&s.nextStreamID)
Expand All @@ -198,7 +226,7 @@ GET_ID:
}

// Register the stream
stream := newStream(s, id, streamInit)
stream := newStream(s, id, streamInit, initialStreamWindow)
s.streamLock.Lock()
s.streams[id] = stream
s.inflight[id] = struct{}{}
Expand Down Expand Up @@ -477,20 +505,20 @@ func (s *Session) sendLoop() error {
// FIXME: https://github.com/libp2p/go-libp2p/issues/644
// Write coalescing is disabled for now.

//writer := pool.Writer{W: s.conn}
// writer := pool.Writer{W: s.conn}

//var writeTimeout *time.Timer
//var writeTimeoutCh <-chan time.Time
//if s.config.WriteCoalesceDelay > 0 {
// var writeTimeout *time.Timer
// var writeTimeoutCh <-chan time.Time
// if s.config.WriteCoalesceDelay > 0 {
// writeTimeout = time.NewTimer(s.config.WriteCoalesceDelay)
// defer writeTimeout.Stop()

// writeTimeoutCh = writeTimeout.C
//} else {
// } else {
// ch := make(chan time.Time)
// close(ch)
// writeTimeoutCh = ch
//}
// }

for {
// yield after processing the last message, if we've shutdown.
Expand Down Expand Up @@ -526,7 +554,7 @@ func (s *Session) sendLoop() error {
copy(buf, hdr[:])
case <-s.shutdownCh:
return nil
//default:
// default:
// select {
// case buf = <-s.sendCh:
// case <-s.shutdownCh:
Expand Down Expand Up @@ -591,6 +619,7 @@ func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
var hdr header
for {
// fmt.Printf("ReadFull from %#v\n", s.reader)
// Read the header
if _, err := io.ReadFull(s.reader, hdr[:]); err != nil {
if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") {
Expand Down Expand Up @@ -733,7 +762,10 @@ func (s *Session) incomingStream(id uint32) error {
}

// Allocate a new stream
stream := newStream(s, id, streamSYNReceived)
if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil {
return err
}
stream := newStream(s, id, streamSYNReceived, initialStreamWindow)

s.streamLock.Lock()
defer s.streamLock.Unlock()
Expand All @@ -744,13 +776,14 @@ func (s *Session) incomingStream(id uint32) error {
if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil {
s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr)
}
s.memoryManager.ReleaseMemory(initialStreamWindow)
return ErrDuplicateStream
}

if s.numIncomingStreams >= s.config.MaxIncomingStreams {
// too many active streams at the same time
s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset")
delete(s.streams, id)
s.memoryManager.ReleaseMemory(initialStreamWindow)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
}
Expand All @@ -766,7 +799,7 @@ func (s *Session) incomingStream(id uint32) error {
default:
// Backlog exceeded! RST the stream
s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset")
delete(s.streams, id)
s.deleteStream(id)
hdr := encode(typeWindowUpdate, flagRST, id, 0)
return s.sendMsg(hdr, nil, nil)
}
Expand All @@ -788,10 +821,19 @@ func (s *Session) closeStream(id uint32) {
if s.client == (id%2 == 0) {
s.numIncomingStreams--
}
delete(s.streams, id)
s.deleteStream(id)
s.streamLock.Unlock()
}

func (s *Session) deleteStream(id uint32) {
str, ok := s.streams[id]
if !ok {
return
}
s.memoryManager.ReleaseMemory(int(str.recvWindow))
delete(s.streams, id)
}

// establishStream is used to mark a stream that was in the
// SYN Sent state as established.
func (s *Session) establishStream(id uint32) {
Expand Down
26 changes: 13 additions & 13 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,16 @@ func testClientServer() (*Session, *Session) {

func testClientServerConfig(conf *Config) (*Session, *Session) {
conn1, conn2 := testConn()
client, _ := Client(conn1, conf)
server, _ := Server(conn2, conf)
client, _ := Client(conn1, conf, nil)
server, _ := Server(conn2, conf, nil)
return client, server
}

func TestClientClient(t *testing.T) {
conf := testConf()
conn1, conn2 := testConn()
client1, _ := Client(conn1, conf)
client2, _ := Client(conn2, conf)
client1, _ := Client(conn1, conf, nil)
client2, _ := Client(conn2, conf, nil)
defer client1.Close()
defer client2.Close()

Expand All @@ -148,8 +148,8 @@ func TestClientClient(t *testing.T) {
func TestServerServer(t *testing.T) {
conf := testConf()
conn1, conn2 := testConn()
server1, _ := Server(conn1, conf)
server2, _ := Server(conn2, conf)
server1, _ := Server(conn1, conf, nil)
server2, _ := Server(conn2, conf, nil)
defer server1.Close()
defer server2.Close()

Expand Down Expand Up @@ -1028,14 +1028,14 @@ func TestKeepAlive_Timeout(t *testing.T) {
clientConf := testConf()
clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
clientConf.EnableKeepAlive = false // Just test one direction, so it's deterministic who hangs up on whom
client, _ := Client(conn1, clientConf)
client, _ := Client(conn1, clientConf, nil)
defer client.Close()

serverLogs := new(logCapture)
serverConf := testConf()
serverConf.LogOutput = serverLogs

server, _ := Server(conn2, serverConf)
server, _ := Server(conn2, serverConf, nil)
defer server.Close()

errCh := make(chan error, 1)
Expand Down Expand Up @@ -1589,7 +1589,7 @@ func TestLotsOfWritesWithStreamDeadline(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer stream2.Reset() //nolint
defer stream2.Reset() // nolint

// wait for the server to accept the streams.
<-waitCh
Expand Down Expand Up @@ -1701,8 +1701,8 @@ func TestInitialStreamWindow(t *testing.T) {
sconf.InitialStreamWindowSize = randomUint32(initialStreamWindow, maxWindow)

conn1, conn2 := testConn()
client, _ := Client(conn1, cconf)
server, _ := Server(conn2, sconf)
client, _ := Client(conn1, cconf, nil)
server, _ := Server(conn2, sconf, nil)

errChan := make(chan error, 1)
go func() {
Expand Down Expand Up @@ -1736,13 +1736,13 @@ func TestInitialStreamWindow(t *testing.T) {
func TestMaxIncomingStreams(t *testing.T) {
const maxIncomingStreams = 5
conn1, conn2 := testConn()
client, err := Client(conn1, DefaultConfig())
client, err := Client(conn1, DefaultConfig(), nil)
require.NoError(t, err)
defer client.Close()

conf := DefaultConfig()
conf.MaxIncomingStreams = maxIncomingStreams
server, err := Server(conn2, conf)
server, err := Server(conn2, conf, nil)
require.NoError(t, err)
defer server.Close()

Expand Down
10 changes: 6 additions & 4 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type Stream struct {

// newStream is used to construct a new stream within
// a given session for an ID
func newStream(session *Session, id uint32, state streamState) *Stream {
func newStream(session *Session, id uint32, state streamState, initialWindow uint32) *Stream {
s := &Stream{
id: id,
session: session,
Expand All @@ -62,7 +62,7 @@ func newStream(session *Session, id uint32, state streamState) *Stream {
// Initialize the recvBuf with initialStreamWindow, not config.InitialStreamWindowSize.
// The peer isn't allowed to send more data than initialStreamWindow until we've sent
// the first window update (which will grant it up to config.InitialStreamWindowSize).
recvBuf: newSegmentedBuffer(initialStreamWindow),
recvBuf: newSegmentedBuffer(initialWindow),
recvWindow: session.config.InitialStreamWindowSize,
epochStart: time.Now(),
recvNotifyCh: make(chan struct{}, 1),
Expand Down Expand Up @@ -225,8 +225,10 @@ func (s *Stream) sendWindowUpdate() error {
recvWindow = min(s.recvWindow*2, s.session.config.MaxStreamWindowSize)
}
if recvWindow > s.recvWindow {
s.recvWindow = recvWindow
_, delta = s.recvBuf.GrowTo(s.recvWindow, true)
if err := s.session.memoryManager.ReserveMemory(int(delta), 128); err == nil {
s.recvWindow = recvWindow
_, delta = s.recvBuf.GrowTo(s.recvWindow, true)
}
}
}
s.epochStart = now
Expand Down

0 comments on commit e7b8807

Please sign in to comment.