From d5b0914fe7aca5ac245f23287e4c19d1ec71c4c5 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Thu, 7 Dec 2023 17:29:38 +0800 Subject: [PATCH] feat: add WithReadThreshold API --- connection_errors.go | 17 ++++--- connection_impl.go | 69 ++++++++++++++++--------- connection_onevent.go | 1 + connection_reactor.go | 43 ++++++++++++++-- connection_test.go | 110 ++++++++++++++++++++++++++++++++++++---- docs/guide/guide_cn.md | 20 ++++++++ docs/guide/guide_en.md | 24 +++++++++ eventloop.go | 30 +++++------ fd_operator.go | 25 ++++++--- mux/shard_queue_test.go | 45 +++++++--------- net_dialer.go | 21 +++++--- net_dialer_test.go | 8 +-- net_polldesc_test.go | 8 +-- net_sock.go | 38 +++++++------- net_tcpsock.go | 30 ++++++----- net_unixsock.go | 28 +++++----- netpoll_options.go | 21 +++++--- netpoll_test.go | 108 +++++++++++++++++++++++++++++++++++++++ nocopy.go | 6 +-- poll.go | 16 ++++-- poll_default_bsd.go | 20 ++++++++ poll_default_linux.go | 21 ++++++++ sys_exec.go | 6 ++- 23 files changed, 551 insertions(+), 164 deletions(-) diff --git a/connection_errors.go b/connection_errors.go index b08ba668..85cc4c78 100644 --- a/connection_errors.go +++ b/connection_errors.go @@ -35,6 +35,8 @@ const ( ErrEOF = syscall.Errno(0x106) // Write I/O buffer timeout, calling by Connection.Writer ErrWriteTimeout = syscall.Errno(0x107) + // The wait read size large than read threshold + ErrReadOutOfThreshold = syscall.Errno(0x108) ) const ErrnoMask = 0xFF @@ -90,11 +92,12 @@ func (e *exception) Unwrap() error { // Errors defined in netpoll var errnos = [...]string{ - ErrnoMask & ErrConnClosed: "connection has been closed", - ErrnoMask & ErrReadTimeout: "connection read timeout", - ErrnoMask & ErrDialTimeout: "dial wait timeout", - ErrnoMask & ErrDialNoDeadline: "dial no deadline", - ErrnoMask & ErrUnsupported: "netpoll dose not support", - ErrnoMask & ErrEOF: "EOF", - ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrConnClosed: "connection has been closed", + ErrnoMask & ErrReadTimeout: "connection read timeout", + ErrnoMask & ErrDialTimeout: "dial wait timeout", + ErrnoMask & ErrDialNoDeadline: "dial no deadline", + ErrnoMask & ErrUnsupported: "netpoll dose not support", + ErrnoMask & ErrEOF: "EOF", + ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrReadOutOfThreshold: "connection read size is out of threshold", } diff --git a/connection_impl.go b/connection_impl.go index 77212de6..291de144 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -33,20 +33,21 @@ type connection struct { netFD onEvent locker - operator *FDOperator - readTimeout time.Duration - readTimer *time.Timer - readTrigger chan error - waitReadSize int64 - writeTimeout time.Duration - writeTimer *time.Timer - writeTrigger chan error - inputBuffer *LinkBuffer - outputBuffer *LinkBuffer - outputBarrier *barrier - supportZeroCopy bool - maxSize int // The maximum size of data between two Release(). - bookSize int // The size of data that can be read at once. + operator *FDOperator + readTimeout time.Duration + readTimer *time.Timer + readTrigger chan error + waitReadSize int64 + writeTimeout time.Duration + writeTimer *time.Timer + writeTrigger chan error + inputBuffer *LinkBuffer + outputBuffer *LinkBuffer + outputBarrier *barrier + supportZeroCopy bool + maxSize int // The maximum size of data between two Release(). + bookSize int // The size of data that can be read at once. + readBufferThreshold int64 // The readBufferThreshold limit the size of connection inputBuffer. In bytes. } var ( @@ -94,6 +95,12 @@ func (c *connection) SetWriteTimeout(timeout time.Duration) error { return nil } +// SetReadBufferThreshold implements Connection. +func (c *connection) SetReadBufferThreshold(threshold int64) error { + c.readBufferThreshold = threshold + return nil +} + // ------------------------------------------ implement zero-copy reader ------------------------------------------ // Next implements Connection. @@ -394,28 +401,44 @@ func (c *connection) triggerWrite(err error) { // waitRead will wait full n bytes. func (c *connection) waitRead(n int) (err error) { if n <= c.inputBuffer.Len() { - return nil + goto CLEANUP } + // cannot wait read with an out of threshold size + if c.readBufferThreshold > 0 && int64(n) > c.readBufferThreshold { + // just return error and dont do cleanup + return Exception(ErrReadOutOfThreshold, "wait read") + } + atomic.StoreInt64(&c.waitReadSize, int64(n)) - defer atomic.StoreInt64(&c.waitReadSize, 0) if c.readTimeout > 0 { - return c.waitReadWithTimeout(n) + err = c.waitReadWithTimeout(n) + goto CLEANUP } // wait full n for c.inputBuffer.Len() < n { switch c.status(closing) { case poller: - return Exception(ErrEOF, "wait read") + err = Exception(ErrEOF, "wait read") case user: - return Exception(ErrConnClosed, "wait read") + err = Exception(ErrConnClosed, "wait read") default: err = <-c.readTrigger - if err != nil { - return err - } + } + if err != nil { + goto CLEANUP } } - return nil +CLEANUP: + atomic.StoreInt64(&c.waitReadSize, 0) + if c.readBufferThreshold > 0 && err == nil { + // only resume read when current read size could make newBufferSize < readBufferThreshold + bufferSize := int64(c.inputBuffer.Len()) + newBufferSize := bufferSize - int64(n) + if bufferSize >= c.readBufferThreshold && newBufferSize < c.readBufferThreshold { + c.resumeRead() + } + } + return err } // waitReadWithTimeout will wait full n bytes or until timeout. diff --git a/connection_onevent.go b/connection_onevent.go index 6f055f37..f2893fa1 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -103,6 +103,7 @@ func (c *connection) onPrepare(opts *options) (err error) { c.SetReadTimeout(opts.readTimeout) c.SetWriteTimeout(opts.writeTimeout) c.SetIdleTimeout(opts.idleTimeout) + c.SetReadBufferThreshold(opts.readBufferThreshold) // calling prepare first and then register. if opts.onPrepare != nil { diff --git a/connection_reactor.go b/connection_reactor.go index cd5d717c..eb5620ca 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -104,6 +104,11 @@ func (c *connection) inputAck(n int) (err error) { c.maxSize = mallocMax } + // trigger throttle + if c.readBufferThreshold > 0 && int64(length) >= c.readBufferThreshold { + c.pauseRead() + } + var needTrigger = true if length == n { // first start onRequest needTrigger = c.onRequest() @@ -117,7 +122,7 @@ func (c *connection) inputAck(n int) (err error) { // outputs implements FDOperator. func (c *connection) outputs(vs [][]byte) (rs [][]byte, supportZeroCopy bool) { if c.outputBuffer.IsEmpty() { - c.rw2r() + c.pauseWrite() return rs, c.supportZeroCopy } rs = c.outputBuffer.GetBytes(vs) @@ -131,13 +136,41 @@ func (c *connection) outputAck(n int) (err error) { c.outputBuffer.Release() } if c.outputBuffer.IsEmpty() { - c.rw2r() + c.pauseWrite() } return nil } -// rw2r removed the monitoring of write events. -func (c *connection) rw2r() { - c.operator.Control(PollRW2R) +// pauseWrite removed the monitoring of write events. +// pauseWrite used in poller +func (c *connection) pauseWrite() { + switch c.operator.getMode() { + case opreadwrite: + c.operator.Control(PollRW2R) + case opwrite: + c.operator.Control(PollW2Hup) + } c.triggerWrite(nil) } + +// pauseRead removed the monitoring of read events. +// pauseRead used in poller +func (c *connection) pauseRead() { + switch c.operator.getMode() { + case opread: + c.operator.Control(PollR2Hup) + case opreadwrite: + c.operator.Control(PollRW2W) + } +} + +// resumeRead add the monitoring of read events. +// resumeRead used by users +func (c *connection) resumeRead() { + switch c.operator.getMode() { + case ophup: + c.operator.Control(PollHup2R) + case opwrite: + c.operator.Control(PollW2RW) + } +} diff --git a/connection_test.go b/connection_test.go index 782e85c2..9a5c551a 100644 --- a/connection_test.go +++ b/connection_test.go @@ -499,18 +499,15 @@ func TestConnDetach(t *testing.T) { func TestParallelShortConnection(t *testing.T) { ln, err := createTestListener("tcp", ":12345") MustNil(t, err) - defer ln.Close() - var received int64 el, err := NewEventLoop(func(ctx context.Context, connection Connection) error { data, err := connection.Reader().Next(connection.Reader().Len()) - if err != nil { - return err - } + Assert(t, err == nil || errors.Is(err, ErrEOF)) atomic.AddInt64(&received, int64(len(data))) - //t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) + t.Logf("conn[%s] received: %d, active: %v", connection.RemoteAddr(), len(data), connection.IsActive()) return nil }) + defer el.Shutdown(context.Background()) go func() { el.Serve(ln) }() @@ -536,10 +533,11 @@ func TestParallelShortConnection(t *testing.T) { } wg.Wait() - for atomic.LoadInt64(&received) < int64(totalSize) { - t.Logf("received: %d, except: %d", atomic.LoadInt64(&received), totalSize) + start := time.Now() + for atomic.LoadInt64(&received) < int64(totalSize) && time.Now().Sub(start) < time.Second { time.Sleep(time.Millisecond * 100) } + Equal(t, atomic.LoadInt64(&received), int64(totalSize)) } func TestConnectionServerClose(t *testing.T) { @@ -643,8 +641,6 @@ func TestConnectionServerClose(t *testing.T) { func TestConnectionDailTimeoutAndClose(t *testing.T) { ln, err := createTestListener("tcp", ":12345") MustNil(t, err) - defer ln.Close() - el, err := NewEventLoop( func(ctx context.Context, connection Connection) error { _, err = connection.Reader().Next(connection.Reader().Len()) @@ -668,10 +664,102 @@ func TestConnectionDailTimeoutAndClose(t *testing.T) { go func() { defer wg.Done() conn, err := DialConnection("tcp", ":12345", time.Nanosecond) - Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout")) + Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout"), err) _ = conn }() } wg.Wait() } } + +func TestConnectionReadOutOfThreshold(t *testing.T) { + var readThreshold = 1024 * 100 + var readSize = readThreshold + 1 + var opts = &options{} + var wg sync.WaitGroup + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + defer wg.Done() + // read throttled data + _, err := connection.Reader().Next(readSize) + Assert(t, errors.Is(err, ErrReadOutOfThreshold), err) + connection.Close() + return nil + } + + WithReadBufferThreshold(int64(readThreshold)).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + wg.Wait() +} + +func TestConnectionReadThreshold(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var wg sync.WaitGroup + var throttled int32 + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + defer wg.Done() + + atomic.StoreInt32(&throttled, 1) + // check if no more read data when throttled + inbuffered := connection.Reader().Len() + t.Logf("Inbuffered: %d", inbuffered) + time.Sleep(time.Millisecond * 100) + Equal(t, inbuffered, connection.Reader().Len()) + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data + buf, err = connection.Reader().Next(5) + MustNil(t, err) + t.Logf("read throttled data: [%s]", buf) + Equal(t, len(buf), 5) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + Equal(t, connection.Reader().Len(), 0) + return nil + } + + WithReadBufferThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readBufferThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + + wg.Wait() +} diff --git a/docs/guide/guide_cn.md b/docs/guide/guide_cn.md index f9b9a0db..bffa9b29 100644 --- a/docs/guide/guide_cn.md +++ b/docs/guide/guide_cn.md @@ -519,6 +519,26 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. 如何配置连接的读取阈值大小 ? + +Netpoll 默认不会对端发送数据的读取速度有任何限制,每当连接有数据时,Netpoll 会尽可能快地将数据存放在自己的 buffer 中。但有时候可能用户不希望数据过快发送,或者是希望控制服务内存使用量,又或者业务 OnRequest 回调处理速度很慢需要限制发送方速度,此时可以使用 `WithReadThreshold` 来控制读取的最大阈值。 + +### Client 侧使用 + +``` +dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server 侧使用 + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # 注意事项 ## 1. 错误设置 NumLoops diff --git a/docs/guide/guide_en.md b/docs/guide/guide_en.md index 08c522f3..1cbbcd8d 100644 --- a/docs/guide/guide_en.md +++ b/docs/guide/guide_en.md @@ -558,6 +558,30 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. How to configure the read threshold of the connection? + +By default, Netpoll does not place any limit on the reading speed of data sent by the end. +Whenever there have more data on the connection, Netpoll will read the data into its own buffer as quickly as possible. + +But sometimes users may not want data to be read too quickly, or they want to control the service memory usage, or the user's OnRequest callback processing data very slowly and need to control the peer's send speed. +In this case, you can use `WithReadThreshold` to control the maximum reading threshold. + +### Client side use + +``` +dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server side use + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # Attention ## 1. Wrong setting of NumLoops diff --git a/eventloop.go b/eventloop.go index c9a903c0..333e2833 100644 --- a/eventloop.go +++ b/eventloop.go @@ -54,27 +54,27 @@ type OnPrepare func(connection Connection) context.Context // // An example usage in TCP Proxy scenario: // -// func onConnect(ctx context.Context, upstream netpoll.Connection) context.Context { -// downstream, _ := netpoll.DialConnection("tcp", downstreamAddr, time.Second) -// return context.WithValue(ctx, downstreamKey, downstream) -// } -// func onRequest(ctx context.Context, upstream netpoll.Connection) error { -// downstream := ctx.Value(downstreamKey).(netpoll.Connection) -// } +// func onConnect(ctx context.Context, upstream netpoll.Connection) context.Context { +// downstream, _ := netpoll.DialConnection("tcp", downstreamAddr, time.Second) +// return context.WithValue(ctx, downstreamKey, downstream) +// } +// func onRequest(ctx context.Context, upstream netpoll.Connection) error { +// downstream := ctx.Value(downstreamKey).(netpoll.Connection) +// } type OnConnect func(ctx context.Context, connection Connection) context.Context // OnRequest defines the function for handling connection. When data is sent from the connection peer, // netpoll actively reads the data in LT mode and places it in the connection's input buffer. // Generally, OnRequest starts handling the data in the following way: // -// func OnRequest(ctx context, connection Connection) error { -// input := connection.Reader().Next(n) -// handling input data... -// send, _ := connection.Writer().Malloc(l) -// copy(send, output) -// connection.Flush() -// return nil -// } +// func OnRequest(ctx context, connection Connection) error { +// input := connection.Reader().Next(n) +// handling input data... +// send, _ := connection.Writer().Malloc(l) +// copy(send, output) +// connection.Flush() +// return nil +// } // // OnRequest will run in a separate goroutine and // it is guaranteed that there is one and only one OnRequest running at the same time. diff --git a/fd_operator.go b/fd_operator.go index 1ac843a9..b94e3825 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -19,6 +19,15 @@ import ( "sync/atomic" ) +const ( + opdetach int32 = -1 + _ int32 = 0 // default op mode, means nothing + opread int32 = 1 + opwrite int32 = 2 + opreadwrite int32 = 3 + ophup int32 = 4 +) + // FDOperator is a collection of operations on file descriptors. type FDOperator struct { // FD is file descriptor, poll will bind when register. @@ -42,8 +51,7 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - // protect only detach once - detached int32 + mode int32 // private, used by operatorCache next *FDOperator @@ -52,9 +60,6 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { - if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { - return nil - } return op.poll.Control(op, event) } @@ -62,6 +67,14 @@ func (op *FDOperator) Free() { op.poll.Free(op) } +func (op *FDOperator) getMode() int32 { + return atomic.LoadInt32(&op.mode) +} + +func (op *FDOperator) setMode(mode int32) { + atomic.StoreInt32(&op.mode, mode) +} + func (op *FDOperator) do() (can bool) { return atomic.CompareAndSwapInt32(&op.state, 1, 2) } @@ -98,5 +111,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil - op.detached = 0 + op.mode = 0 } diff --git a/mux/shard_queue_test.go b/mux/shard_queue_test.go index 7a595d21..b0d3f5b4 100644 --- a/mux/shard_queue_test.go +++ b/mux/shard_queue_test.go @@ -19,6 +19,7 @@ package mux import ( "net" + "sync" "testing" "time" @@ -26,28 +27,25 @@ import ( ) func TestShardQueue(t *testing.T) { - var svrConn net.Conn accepted := make(chan struct{}) network, address := "tcp", ":18888" ln, err := net.Listen("tcp", ":18888") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) + count, pkgsize := 16, 11 + var wg sync.WaitGroup + wg.Add(1) go func() { - var err error - for { - select { - case <-stop: - err = ln.Close() - MustNil(t, err) - return - default: - } - svrConn, err = ln.Accept() - MustNil(t, err) - accepted <- struct{}{} - } + defer wg.Done() + svrConn, err := ln.Accept() + MustNil(t, err) + accepted <- struct{}{} + + total := count * pkgsize + recv := make([]byte, total) + rn, err := svrConn.Read(recv) + MustNil(t, err) + Equal(t, rn, total) }() conn, err := netpoll.DialConnection(network, address, time.Second) @@ -56,8 +54,7 @@ func TestShardQueue(t *testing.T) { // test queue := NewShardQueue(4, conn) - count, pkgsize := 16, 11 - for i := 0; i < int(count); i++ { + for i := 0; i < count; i++ { var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) { buf = netpoll.NewLinkBuffer(pkgsize) buf.Malloc(pkgsize) @@ -68,14 +65,8 @@ func TestShardQueue(t *testing.T) { err = queue.Close() MustNil(t, err) - total := count * pkgsize - recv := make([]byte, total) - rn, err := svrConn.Read(recv) - MustNil(t, err) - Equal(t, rn, total) -} -// TODO: need mock flush -func BenchmarkShardQueue(b *testing.B) { - b.Skip() + wg.Wait() + err = ln.Close() + MustNil(t, err) } diff --git a/net_dialer.go b/net_dialer.go index 4c4e8dd2..edcafca1 100644 --- a/net_dialer.go +++ b/net_dialer.go @@ -29,13 +29,22 @@ func DialConnection(network, address string, timeout time.Duration) (connection } // NewDialer only support TCP and unix socket now. -func NewDialer() Dialer { - return &dialer{} +func NewDialer(opts ...Option) Dialer { + d := new(dialer) + if len(opts) > 0 { + d.opts = new(options) + for _, opt := range opts { + opt.f(d.opts) + } + } + return d } var defaultDialer = NewDialer() -type dialer struct{} +type dialer struct { + opts *options +} // DialTimeout implements Dialer. func (d *dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { @@ -59,7 +68,7 @@ func (d *dialer) DialConnection(network, address string, timeout time.Duration) raddr := &UnixAddr{ UnixAddr: net.UnixAddr{Name: address, Net: network}, } - return DialUnix(network, nil, raddr) + return dialUnix(network, nil, raddr, d.opts) default: return nil, net.UnknownNetworkError(network) } @@ -95,9 +104,9 @@ func (d *dialer) dialTCP(ctx context.Context, network, address string) (connecti tcpAddr.Port = portnum tcpAddr.Zone = ipaddr.Zone if ipaddr.IP != nil && ipaddr.IP.To4() == nil { - connection, err = DialTCP(ctx, "tcp6", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp6", nil, tcpAddr, d.opts) } else { - connection, err = DialTCP(ctx, "tcp", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp", nil, tcpAddr, d.opts) } if err == nil { return connection, nil diff --git a/net_dialer_test.go b/net_dialer_test.go index 7383fd0d..deca3889 100644 --- a/net_dialer_test.go +++ b/net_dialer_test.go @@ -38,15 +38,14 @@ func TestDialerTCP(t *testing.T) { ln, err := CreateListener("tcp", ":1234") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) - + stop := make(chan int) go func() { for { select { case <-stop: err := ln.Close() MustNil(t, err) + close(stop) return default: } @@ -61,6 +60,9 @@ func TestDialerTCP(t *testing.T) { MustNil(t, err) MustTrue(t, strings.HasPrefix(conn.LocalAddr().String(), "127.0.0.1:")) Equal(t, conn.RemoteAddr().String(), "127.0.0.1:1234") + + stop <- 0 + <-stop } func TestDialerUnix(t *testing.T) { diff --git a/net_polldesc_test.go b/net_polldesc_test.go index 40804b62..6f379167 100644 --- a/net_polldesc_test.go +++ b/net_polldesc_test.go @@ -30,15 +30,14 @@ func TestRuntimePoll(t *testing.T) { ln, err := CreateListener("tcp", ":1234") MustNil(t, err) - stop := make(chan int, 1) - defer close(stop) - + stop := make(chan int) go func() { for { select { case <-stop: err := ln.Close() MustNil(t, err) + close(stop) return default: } @@ -54,4 +53,7 @@ func TestRuntimePoll(t *testing.T) { MustNil(t, err) conn.Close() } + + stop <- 0 + <-stop } diff --git a/net_sock.go b/net_sock.go index a3d318c7..c6ec98e8 100644 --- a/net_sock.go +++ b/net_sock.go @@ -55,29 +55,29 @@ func internetSocket(ctx context.Context, net string, laddr, raddr sockaddr, soty // address family, both AF_INET and AF_INET6, and a wildcard address // like the following: // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with a wildcard address: If the platform supports -// both IPv6 and IPv4-mapped IPv6 communication capabilities, -// or does not support IPv4, we use a dual stack, AF_INET6 and -// IPV6_V6ONLY=0, wildcard address listen. The dual stack -// wildcard address listen may fall back to an IPv6-only, -// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. -// Otherwise we prefer an IPv4-only, AF_INET, wildcard address -// listen. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with a wildcard address: If the platform supports +// both IPv6 and IPv4-mapped IPv6 communication capabilities, +// or does not support IPv4, we use a dual stack, AF_INET6 and +// IPV6_V6ONLY=0, wildcard address listen. The dual stack +// wildcard address listen may fall back to an IPv6-only, +// AF_INET6 and IPV6_V6ONLY=1, wildcard address listen. +// Otherwise we prefer an IPv4-only, AF_INET, wildcard address +// listen. // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with an IPv4 wildcard address: same as above. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv4 wildcard address: same as above. // -// - A listen for a wildcard communication domain, "tcp" or -// "udp", with an IPv6 wildcard address: same as above. +// - A listen for a wildcard communication domain, "tcp" or +// "udp", with an IPv6 wildcard address: same as above. // -// - A listen for an IPv4 communication domain, "tcp4" or "udp4", -// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, -// wildcard address listen. +// - A listen for an IPv4 communication domain, "tcp4" or "udp4", +// with an IPv4 wildcard address: We use an IPv4-only, AF_INET, +// wildcard address listen. // -// - A listen for an IPv6 communication domain, "tcp6" or "udp6", -// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 -// and IPV6_V6ONLY=1, wildcard address listen. +// - A listen for an IPv6 communication domain, "tcp6" or "udp6", +// with an IPv6 wildcard address: We use an IPv6-only, AF_INET6 +// and IPV6_V6ONLY=1, wildcard address listen. // // Otherwise guess: If the addresses are IPv4 then returns AF_INET, // or else returns AF_INET6. It also returns a boolean value what diff --git a/net_tcpsock.go b/net_tcpsock.go index 2c90634b..87fb84eb 100644 --- a/net_tcpsock.go +++ b/net_tcpsock.go @@ -138,23 +138,16 @@ type TCPConnection struct { } // newTCPConnection wraps *TCPConnection. -func newTCPConnection(conn Conn) (connection *TCPConnection, err error) { +func newTCPConnection(conn Conn, opts *options) (connection *TCPConnection, err error) { connection = &TCPConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialTCP acts like Dial for TCP networks. -// -// The network must be a TCP network name; see func Dial for details. -// -// If laddr is nil, a local address is automatically chosen. -// If the IP field of raddr is nil or an unspecified IP address, the -// local system is assumed. -func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { +func dialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -167,14 +160,25 @@ func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPCo ctx = context.Background() } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialTCP(ctx, laddr, raddr) + c, err := sd.dialTCP(ctx, laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConnection, error) { +// DialTCP acts like Dial for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { + return dialTCP(ctx, network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { conn, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") // TCP has a rarely used mechanism called a 'simultaneous connection' in @@ -211,7 +215,7 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo if err != nil { return nil, err } - return newTCPConnection(conn) + return newTCPConnection(conn, opts) } func selfConnect(conn *netFD, err error) bool { diff --git a/net_unixsock.go b/net_unixsock.go index c5213a1c..9564dbc7 100644 --- a/net_unixsock.go +++ b/net_unixsock.go @@ -74,41 +74,45 @@ type UnixConnection struct { } // newUnixConnection wraps UnixConnection. -func newUnixConnection(conn Conn) (connection *UnixConnection, err error) { +func newUnixConnection(conn Conn, opts *options) (connection *UnixConnection, err error) { connection = &UnixConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialUnix acts like Dial for Unix networks. -// -// The network must be a Unix network name; see func Dial for details. -// -// If laddr is non-nil, it is used as the local address for the -// connection. -func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { +func dialUnix(network string, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { switch network { case "unix", "unixgram", "unixpacket": default: return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: net.UnknownNetworkError(network)} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialUnix(context.Background(), laddr, raddr) + c, err := sd.dialUnix(context.Background(), laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConnection, error) { +// DialUnix acts like Dial for Unix networks. +// +// The network must be a Unix network name; see func Dial for details. +// +// If laddr is non-nil, it is used as the local address for the +// connection. +func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { + return dialUnix(network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { conn, err := unixSocket(ctx, sd.network, laddr, raddr, "dial") if err != nil { return nil, err } - return newUnixConnection(conn) + return newUnixConnection(conn, opts) } func unixSocket(ctx context.Context, network string, laddr, raddr sockaddr, mode string) (conn *netFD, err error) { diff --git a/netpoll_options.go b/netpoll_options.go index 023e574c..633b8f95 100644 --- a/netpoll_options.go +++ b/netpoll_options.go @@ -98,16 +98,25 @@ func WithIdleTimeout(timeout time.Duration) Option { }} } +// WithReadBufferThreshold sets the max read buffer threshold. +// If connection already read the threshold bytes data, it will stop read more data. +func WithReadBufferThreshold(threshold int64) Option { + return Option{func(op *options) { + op.readBufferThreshold = threshold + }} +} + // Option . type Option struct { f func(*options) } type options struct { - onPrepare OnPrepare - onConnect OnConnect - onRequest OnRequest - readTimeout time.Duration - writeTimeout time.Duration - idleTimeout time.Duration + onPrepare OnPrepare + onConnect OnConnect + onRequest OnRequest + readTimeout time.Duration + writeTimeout time.Duration + idleTimeout time.Duration + readBufferThreshold int64 // bytes } diff --git a/netpoll_test.go b/netpoll_test.go index 0467e879..c77c0cca 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -397,6 +397,114 @@ func TestClientWriteAndClose(t *testing.T) { MustNil(t, err) } +func TestReadThresholdOption(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + server write: 102400 bytes + 5 bytes + client cached: 102400 bytes, and throttled + client read: 102400 bytes, and unthrottled + client cached: 5 bytes + client read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + var wg sync.WaitGroup + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + wg.Add(3) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + go func() { + defer wg.Done() + // server write + t.Logf("server writing msg1") + _, err := connection.Writer().WriteBinary(msg1) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("server writing msg2") + _, err = connection.Writer().WriteBinary(msg2) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + }() + + // server read + defer wg.Done() + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + time.Sleep(time.Millisecond * 100) // ensure client send msg2 + Equal(t, connection.Reader().Len(), readThreshold) + msg, err := connection.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("server reading msg2") + msg, err = connection.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + _, err = connection.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + t.Logf("server closed") + return nil + }, WithReadBufferThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadBufferThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + go func() { + defer wg.Done() + t.Logf("client writing msg1") + _, err := cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + }() + + // client read + trigger <- struct{}{} // let server send msg2 + time.Sleep(time.Millisecond * 100) // ensure server send msg2 + Equal(t, cli.Reader().Len(), readThreshold) + t.Logf("client reading msg1") + msg, err := cli.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("client reading msg2") + msg, err = cli.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + wg.Wait() +} + func createTestListener(network, address string) (Listener, error) { for { ln, err := CreateListener(network, address) diff --git a/nocopy.go b/nocopy.go index 80df5f9b..47ad2c6c 100644 --- a/nocopy.go +++ b/nocopy.go @@ -108,9 +108,9 @@ type Reader interface { // The usage of the design is a two-step operation, first apply for a section of memory, // fill it and then submit. E.g: // -// var buf, _ = Malloc(n) -// buf = append(buf[:0], ...) -// Flush() +// var buf, _ = Malloc(n) +// buf = append(buf[:0], ...) +// Flush() // // Note that it is not recommended to submit self-managed buffers to Writer. // Since the writer is processed asynchronously, if the self-managed buffer is used and recycled after submission, diff --git a/poll.go b/poll.go index c494ffd6..649bf42f 100644 --- a/poll.go +++ b/poll.go @@ -59,8 +59,18 @@ const ( // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. - PollR2RW PollEvent = 0x5 - + PollR2RW PollEvent = 0x4 // PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW. - PollRW2R PollEvent = 0x6 + PollRW2R PollEvent = 0x5 + + // PollRW2W is used to remove the readable monitor of FDOperator. + PollRW2W PollEvent = 0x6 + // PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W. + PollW2RW PollEvent = 0x7 + PollW2Hup PollEvent = 0x8 + + // PollR2Hup is used to remove the readable monitor of FDOperator. + PollR2Hup PollEvent = 0x9 + // PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup. + PollHup2R PollEvent = 0x10 ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 9c8aa8c9..a6488cb2 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -182,11 +182,14 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: operator.inuse() + operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() + operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: + operator.setMode(ophup) if operator.OnWrite != nil { // means WaitWrite finished evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } else { @@ -194,9 +197,26 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { } p.delOperator(operator) case PollR2RW: + operator.setMode(opreadwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: + operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + case PollRW2W: + operator.setMode(opwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + case PollW2RW: + operator.setMode(opreadwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE + case PollR2Hup: + operator.setMode(ophup) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + case PollW2Hup: + operator.setMode(ophup) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + case PollHup2R: + operator.setMode(opread) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_linux.go b/poll_default_linux.go index a0087ee0..72737370 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -244,16 +244,37 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: // server accept a new connection and wait read operator.inuse() + operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_ADD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollWritable: // client create a new connection and wait connect finished operator.inuse() + operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister + if operator.getMode() == opdetach { + // protect only detach once + return nil + } + operator.setMode(opdetach) p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollR2RW: // connection wait read/write + operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2R: // connection wait read + operator.setMode(opread) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollRW2W: + operator.setMode(opwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollW2RW: + operator.setMode(opreadwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollR2Hup, PollW2Hup: + operator.setMode(ophup) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollHup2R: + operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt) diff --git a/sys_exec.go b/sys_exec.go index 1c8e40e4..8a7c5784 100644 --- a/sys_exec.go +++ b/sys_exec.go @@ -94,8 +94,10 @@ func readv(fd int, bs [][]byte, ivs []syscall.Iovec) (n int, err error) { } // TODO: read from sysconf(_SC_IOV_MAX)? The Linux default is -// 1024 and this seems conservative enough for now. Darwin's -// UIO_MAXIOV also seems to be 1024. +// +// 1024 and this seems conservative enough for now. Darwin's +// UIO_MAXIOV also seems to be 1024. +// // iovecs limit length to 2GB(2^31) func iovecs(bs [][]byte, ivs []syscall.Iovec) (iovLen int) { totalLen := 0