Skip to content

Commit

Permalink
fix: add throttled check when connection closed by peer
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Dec 22, 2023
1 parent d5b0914 commit b63d10a
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 7 deletions.
18 changes: 14 additions & 4 deletions connection_reactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,21 +156,31 @@ func (c *connection) pauseWrite() {
// pauseRead removed the monitoring of read events.
// pauseRead used in poller
func (c *connection) pauseRead() {
// Note that the poller ensure that every fd should read all left data in socket buffer before detach it.
// So the operator mode should never be ophup.
var changeTo PollEvent
switch c.operator.getMode() {
case opread:
c.operator.Control(PollR2Hup)
changeTo = PollR2Hup
case opreadwrite:
c.operator.Control(PollRW2W)
changeTo = PollRW2W
}
if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 0, 1) {
c.operator.Control(changeTo)
}
}

// resumeRead add the monitoring of read events.
// resumeRead used by users
func (c *connection) resumeRead() {
var changeTo PollEvent
switch c.operator.getMode() {
case ophup:
c.operator.Control(PollHup2R)
changeTo = PollHup2R
case opwrite:
c.operator.Control(PollW2RW)
changeTo = PollW2RW
}
if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 1, 0) {
c.operator.Control(changeTo)
}
}
62 changes: 62 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,65 @@ func TestConnectionReadThreshold(t *testing.T) {

wg.Wait()
}

func TestConnectionReadThresholdWithClosed(t *testing.T) {
var readThreshold int64 = 1024 * 100
var opts = &options{}
var trigger = make(chan struct{})
opts.onRequest = func(ctx context.Context, connection Connection) error {
if int64(connection.Reader().Len()) < readThreshold {
return nil
}
Equal(t, connection.Reader().Len(), int(readThreshold))
trigger <- struct{}{} // let client send final msg and close
<-trigger // wait for client send and close

// read non-throttled data
buf, err := connection.Reader().Next(int(readThreshold))
Equal(t, int64(len(buf)), readThreshold)
MustNil(t, err)
err = connection.Reader().Release()
MustNil(t, err)
t.Logf("read non-throttled data")

// continue read throttled data
buf, err = connection.Reader().Next(5)
MustNil(t, err)
t.Logf("read throttled data: [%s]", buf)
Equal(t, len(buf), 5)
MustNil(t, err)
err = connection.Reader().Release()
MustNil(t, err)
Equal(t, connection.Reader().Len(), 0)

_, err = connection.Reader().Next(1)
Assert(t, errors.Is(err, ErrEOF))
trigger <- struct{}{}
return nil
}

WithReadBufferThreshold(readThreshold).f(opts)
r, w := GetSysFdPairs()
rconn, wconn := &connection{}, &connection{}
rconn.init(&netFD{fd: r}, opts)
wconn.init(&netFD{fd: w}, opts)
Assert(t, rconn.readBufferThreshold == readThreshold)

msg := make([]byte, readThreshold)
_, err := wconn.Writer().WriteBinary(msg)
MustNil(t, err)
err = wconn.Writer().Flush()
MustNil(t, err)

<-trigger
_, err = wconn.Writer().WriteString("hello")
MustNil(t, err)
err = wconn.Writer().Flush()
MustNil(t, err)
t.Logf("flush final msg")
err = wconn.Close()
MustNil(t, err)
trigger <- struct{}{}

<-trigger
}
4 changes: 3 additions & 1 deletion fd_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ type FDOperator struct {
// poll is the registered location of the file descriptor.
poll Poll

mode int32
mode int32
throttled int32

// private, used by operatorCache
next *FDOperator
Expand Down Expand Up @@ -112,4 +113,5 @@ func (op *FDOperator) reset() {
op.Outputs, op.OutputAck = nil, nil
op.poll = nil
op.mode = 0
op.throttled = 0
}
67 changes: 67 additions & 0 deletions netpoll_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,73 @@ func TestReadThresholdOption(t *testing.T) {
wg.Wait()
}

func TestReadThresholdClosed(t *testing.T) {
/*
client => server: 102400 bytes + 5 bytes
client => server: close connection
server cached: 102400 bytes, and throttled
server read: 102400 bytes, and unthrottled
server cached: 5 bytes
server read: 5 bytes
*/
readThreshold := 1024 * 100
trigger := make(chan struct{})
msg1 := make([]byte, readThreshold)
msg2 := []byte("hello")

// server
ln, err := CreateListener("tcp", ":12345")
MustNil(t, err)
svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error {
if connection.Reader().Len() < readThreshold {
return nil
}
// server read
t.Logf("server reading msg1")
trigger <- struct{}{} // let client send msg2
<-trigger // ensure client send msg2 and closed
total := 0
for {
msg, err := connection.Reader().Next(1)
total += len(msg)
if errors.Is(err, ErrEOF) {
break
}
_ = msg
}
Equal(t, total, readThreshold+5)
close(trigger)
return nil
}, WithReadBufferThreshold(int64(readThreshold)))
defer svr.Shutdown(context.Background())
go func() {
svr.Serve(ln)
}()
time.Sleep(time.Millisecond * 100)

// client write
dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold)))
cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second)
MustNil(t, err)
t.Logf("client writing msg1")
_, err = cli.Writer().WriteBinary(msg1)
MustNil(t, err)
err = cli.Writer().Flush()
MustNil(t, err)
<-trigger
time.Sleep(time.Millisecond * 100)
t.Logf("client writing msg2")
_, err = cli.Writer().WriteBinary(msg2)
MustNil(t, err)
err = cli.Writer().Flush()
MustNil(t, err)
err = cli.Close()
MustNil(t, err)
t.Logf("client closed")
trigger <- struct{}{}
<-trigger
}

func createTestListener(network, address string) (Listener, error) {
for {
ln, err := CreateListener(network, address)
Expand Down
3 changes: 2 additions & 1 deletion poll_default_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ func (p *defaultPoll) Wait() error {
}
}
if triggerHup {
if triggerRead && operator.Inputs != nil {
// if peer closed with throttled state, we should ensure we read all left data to avoid data loss
if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil {
var leftRead int
// read all left data if peer send and close
if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) {
Expand Down
3 changes: 2 additions & 1 deletion poll_default_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) {
}
}
if triggerHup {
if triggerRead && operator.Inputs != nil {
// if peer closed with throttled state, we should ensure we read all left data to avoid data loss
if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil {
// read all left data if peer send and close
var leftRead int
// read all left data if peer send and close
Expand Down

0 comments on commit b63d10a

Please sign in to comment.