diff --git a/connection_reactor.go b/connection_reactor.go index eb5620ca..424ce564 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -156,11 +156,15 @@ func (c *connection) pauseWrite() { // pauseRead removed the monitoring of read events. // pauseRead used in poller func (c *connection) pauseRead() { + // Note that the poller ensure that every fd should read all left data in socket buffer before detach it. + // So the operator mode should never be ophup. switch c.operator.getMode() { case opread: c.operator.Control(PollR2Hup) + c.operator.throttled = true case opreadwrite: c.operator.Control(PollRW2W) + c.operator.throttled = true } } @@ -170,7 +174,9 @@ func (c *connection) resumeRead() { switch c.operator.getMode() { case ophup: c.operator.Control(PollHup2R) + c.operator.throttled = false case opwrite: c.operator.Control(PollW2RW) + c.operator.throttled = false } } diff --git a/connection_test.go b/connection_test.go index 9a5c551a..094ce5b6 100644 --- a/connection_test.go +++ b/connection_test.go @@ -763,3 +763,65 @@ func TestConnectionReadThreshold(t *testing.T) { wg.Wait() } + +func TestConnectionReadThresholdWithClosed(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var trigger = make(chan struct{}) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + Equal(t, connection.Reader().Len(), int(readThreshold)) + trigger <- struct{}{} // let client send final msg and close + <-trigger // wait for client send and close + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data + buf, err = connection.Reader().Next(5) + MustNil(t, err) + t.Logf("read throttled data: [%s]", buf) + Equal(t, len(buf), 5) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + Equal(t, connection.Reader().Len(), 0) + + _, err = connection.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + trigger <- struct{}{} + return nil + } + + WithReadBufferThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readBufferThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + + <-trigger + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + err = wconn.Close() + MustNil(t, err) + trigger <- struct{}{} + + <-trigger +} diff --git a/fd_operator.go b/fd_operator.go index b94e3825..f3fbd827 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -51,7 +51,8 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - mode int32 + mode int32 + throttled bool // private, used by operatorCache next *FDOperator @@ -112,4 +113,5 @@ func (op *FDOperator) reset() { op.Outputs, op.OutputAck = nil, nil op.poll = nil op.mode = 0 + op.throttled = false } diff --git a/netpoll_test.go b/netpoll_test.go index c77c0cca..85a0bef7 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -505,6 +505,73 @@ func TestReadThresholdOption(t *testing.T) { wg.Wait() } +func TestReadThresholdClosed(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + client => server: close connection + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + // server read + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + <-trigger // ensure client send msg2 and closed + total := 0 + for { + msg, err := connection.Reader().Next(1) + total += len(msg) + if errors.Is(err, ErrEOF) { + break + } + _ = msg + } + Equal(t, total, readThreshold+5) + close(trigger) + return nil + }, WithReadBufferThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + t.Logf("client writing msg1") + _, err = cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + trigger <- struct{}{} + <-trigger +} + func createTestListener(network, address string) (Listener, error) { for { ln, err := CreateListener(network, address) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index a6488cb2..51f3acb9 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -115,7 +115,8 @@ func (p *defaultPoll) Wait() error { } } if triggerHup { - if triggerRead && operator.Inputs != nil { + // if peer closed with throttled state, we should ensure we read all left data to avoid data loss + if (triggerRead || operator.throttled) && operator.Inputs != nil { var leftRead int // read all left data if peer send and close if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { diff --git a/poll_default_linux.go b/poll_default_linux.go index 72737370..8da5ae66 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -168,7 +168,8 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { } } if triggerHup { - if triggerRead && operator.Inputs != nil { + // if peer closed with throttled state, we should ensure we read all left data to avoid data loss + if (triggerRead || operator.throttled) && operator.Inputs != nil { // read all left data if peer send and close var leftRead int // read all left data if peer send and close