diff --git a/connection_errors.go b/connection_errors.go index b08ba668..6541cc3d 100644 --- a/connection_errors.go +++ b/connection_errors.go @@ -35,6 +35,8 @@ const ( ErrEOF = syscall.Errno(0x106) // Write I/O buffer timeout, calling by Connection.Writer ErrWriteTimeout = syscall.Errno(0x107) + // The wait read size large than read threshold + ErrReadOutOfThreshold = syscall.Errno(0x108) ) const ErrnoMask = 0xFF @@ -97,4 +99,5 @@ var errnos = [...]string{ ErrnoMask & ErrUnsupported: "netpoll dose not support", ErrnoMask & ErrEOF: "EOF", ErrnoMask & ErrWriteTimeout: "connection write timeout", + ErrnoMask & ErrReadOutOfThreshold: "connection read size is out of threshold", } diff --git a/connection_impl.go b/connection_impl.go index 77212de6..8412c0fe 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -45,8 +45,9 @@ type connection struct { outputBuffer *LinkBuffer outputBarrier *barrier supportZeroCopy bool - maxSize int // The maximum size of data between two Release(). - bookSize int // The size of data that can be read at once. + maxSize int // The maximum size of data between two Release(). + bookSize int // The size of data that can be read at once. + readThreshold int64 // The readThreshold of connection max read. } var ( @@ -94,6 +95,12 @@ func (c *connection) SetWriteTimeout(timeout time.Duration) error { return nil } +// SetReadThreshold implements Connection. +func (c *connection) SetReadThreshold(readThreshold int64) error { + c.readThreshold = readThreshold + return nil +} + // ------------------------------------------ implement zero-copy reader ------------------------------------------ // Next implements Connection. @@ -394,28 +401,44 @@ func (c *connection) triggerWrite(err error) { // waitRead will wait full n bytes. func (c *connection) waitRead(n int) (err error) { if n <= c.inputBuffer.Len() { - return nil + goto CLEANUP } + // cannot wait read with an out of threshold size + if c.readThreshold > 0 && int64(n) > c.readThreshold { + // just return error and dont do cleanup + return Exception(ErrReadOutOfThreshold, "wait read") + } + atomic.StoreInt64(&c.waitReadSize, int64(n)) - defer atomic.StoreInt64(&c.waitReadSize, 0) if c.readTimeout > 0 { - return c.waitReadWithTimeout(n) + err = c.waitReadWithTimeout(n) + goto CLEANUP } // wait full n for c.inputBuffer.Len() < n { switch c.status(closing) { case poller: - return Exception(ErrEOF, "wait read") + err = Exception(ErrEOF, "wait read") case user: - return Exception(ErrConnClosed, "wait read") + err = Exception(ErrConnClosed, "wait read") default: err = <-c.readTrigger - if err != nil { - return err - } + } + if err != nil { + goto CLEANUP } } - return nil +CLEANUP: + atomic.StoreInt64(&c.waitReadSize, 0) + if c.readThreshold > 0 && err == nil { + // only resume read when current read size could make newBufferSize < readThreshold + bufferSize := int64(c.inputBuffer.Len()) + newBufferSize := bufferSize - int64(n) + if bufferSize >= c.readThreshold && newBufferSize < c.readThreshold { + c.resumeRead() + } + } + return err } // waitReadWithTimeout will wait full n bytes or until timeout. diff --git a/connection_onevent.go b/connection_onevent.go index 6f055f37..ae8b435e 100644 --- a/connection_onevent.go +++ b/connection_onevent.go @@ -103,6 +103,7 @@ func (c *connection) onPrepare(opts *options) (err error) { c.SetReadTimeout(opts.readTimeout) c.SetWriteTimeout(opts.writeTimeout) c.SetIdleTimeout(opts.idleTimeout) + c.SetReadThreshold(opts.readThreshold) // calling prepare first and then register. if opts.onPrepare != nil { diff --git a/connection_reactor.go b/connection_reactor.go index cd5d717c..74ed08b1 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -104,6 +104,11 @@ func (c *connection) inputAck(n int) (err error) { c.maxSize = mallocMax } + // trigger throttle + if c.readThreshold > 0 && int64(length) >= c.readThreshold { + c.pauseRead() + } + var needTrigger = true if length == n { // first start onRequest needTrigger = c.onRequest() @@ -138,6 +143,29 @@ func (c *connection) outputAck(n int) (err error) { // rw2r removed the monitoring of write events. func (c *connection) rw2r() { - c.operator.Control(PollRW2R) + switch c.operator.getMode() { + case opreadwrite: + c.operator.Control(PollRW2R) + case opwrite: + c.operator.Control(PollW2RW) + } c.triggerWrite(nil) } + +func (c *connection) pauseRead() { + switch c.operator.getMode() { + case opread: + c.operator.Control(PollR2Hup) + case opreadwrite: + c.operator.Control(PollRW2W) + } +} + +func (c *connection) resumeRead() { + switch c.operator.getMode() { + case ophup: + c.operator.Control(PollHup2R) + case opwrite: + c.operator.Control(PollW2RW) + } +} diff --git a/connection_test.go b/connection_test.go index 782e85c2..743a8e39 100644 --- a/connection_test.go +++ b/connection_test.go @@ -675,3 +675,95 @@ func TestConnectionDailTimeoutAndClose(t *testing.T) { wg.Wait() } } + +func TestConnectionReadOutOfThreshold(t *testing.T) { + var readThreshold = 1024 * 100 + var readSize = readThreshold + 1 + var opts = &options{} + var wg sync.WaitGroup + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + defer wg.Done() + // read throttled data + _, err := connection.Reader().Next(readSize) + Assert(t, errors.Is(err, ErrReadOutOfThreshold), err) + connection.Close() + return nil + } + + WithReadThreshold(int64(readThreshold)).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + wg.Wait() +} + +func TestConnectionReadThreshold(t *testing.T) { + var readThreshold int64 = 1024 * 100 + var opts = &options{} + var wg sync.WaitGroup + var throttled int32 + wg.Add(1) + opts.onRequest = func(ctx context.Context, connection Connection) error { + if int64(connection.Reader().Len()) < readThreshold { + return nil + } + defer wg.Done() + + atomic.StoreInt32(&throttled, 1) + // check if no more read data when throttled + inbuffered := connection.Reader().Len() + t.Logf("Inbuffered: %d", inbuffered) + time.Sleep(time.Millisecond * 100) + Equal(t, inbuffered, connection.Reader().Len()) + + // read non-throttled data + buf, err := connection.Reader().Next(int(readThreshold)) + Equal(t, int64(len(buf)), readThreshold) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + t.Logf("read non-throttled data") + + // continue read throttled data + buf, err = connection.Reader().Next(5) + MustNil(t, err) + t.Logf("read throttled data: [%s]", buf) + Equal(t, len(buf), 5) + MustNil(t, err) + err = connection.Reader().Release() + MustNil(t, err) + Equal(t, connection.Reader().Len(), 0) + return nil + } + + WithReadThreshold(readThreshold).f(opts) + r, w := GetSysFdPairs() + rconn, wconn := &connection{}, &connection{} + rconn.init(&netFD{fd: r}, opts) + wconn.init(&netFD{fd: w}, opts) + Assert(t, rconn.readThreshold == readThreshold) + + msg := make([]byte, readThreshold) + _, err := wconn.Writer().WriteBinary(msg) + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + _, err = wconn.Writer().WriteString("hello") + MustNil(t, err) + err = wconn.Writer().Flush() + MustNil(t, err) + t.Logf("flush final msg") + + wg.Wait() +} diff --git a/docs/guide/guide_cn.md b/docs/guide/guide_cn.md index f9b9a0db..bffa9b29 100644 --- a/docs/guide/guide_cn.md +++ b/docs/guide/guide_cn.md @@ -519,6 +519,26 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. 如何配置连接的读取阈值大小 ? + +Netpoll 默认不会对端发送数据的读取速度有任何限制,每当连接有数据时,Netpoll 会尽可能快地将数据存放在自己的 buffer 中。但有时候可能用户不希望数据过快发送,或者是希望控制服务内存使用量,又或者业务 OnRequest 回调处理速度很慢需要限制发送方速度,此时可以使用 `WithReadThreshold` 来控制读取的最大阈值。 + +### Client 侧使用 + +``` +dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server 侧使用 + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # 注意事项 ## 1. 错误设置 NumLoops diff --git a/docs/guide/guide_en.md b/docs/guide/guide_en.md index 08c522f3..1cbbcd8d 100644 --- a/docs/guide/guide_en.md +++ b/docs/guide/guide_en.md @@ -558,6 +558,30 @@ func callback(connection netpoll.Connection) error { } ``` +## 8. How to configure the read threshold of the connection? + +By default, Netpoll does not place any limit on the reading speed of data sent by the end. +Whenever there have more data on the connection, Netpoll will read the data into its own buffer as quickly as possible. + +But sometimes users may not want data to be read too quickly, or they want to control the service memory usage, or the user's OnRequest callback processing data very slowly and need to control the peer's send speed. +In this case, you can use `WithReadThreshold` to control the maximum reading threshold. + +### Client side use + +``` +dialer := netpoll.NewDialer(netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1)) // 1GB +conn, _ = dialer.DialConnection(network, address, timeout) +``` + +### Server side use + +``` +eventLoop, _ := netpoll.NewEventLoop( + handle, + netpoll.WithReadThreshold(1024 * 1024 * 1024 * 1), // 1GB +) +``` + # Attention ## 1. Wrong setting of NumLoops diff --git a/fd_operator.go b/fd_operator.go index 1ac843a9..b94e3825 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -19,6 +19,15 @@ import ( "sync/atomic" ) +const ( + opdetach int32 = -1 + _ int32 = 0 // default op mode, means nothing + opread int32 = 1 + opwrite int32 = 2 + opreadwrite int32 = 3 + ophup int32 = 4 +) + // FDOperator is a collection of operations on file descriptors. type FDOperator struct { // FD is file descriptor, poll will bind when register. @@ -42,8 +51,7 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - // protect only detach once - detached int32 + mode int32 // private, used by operatorCache next *FDOperator @@ -52,9 +60,6 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { - if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { - return nil - } return op.poll.Control(op, event) } @@ -62,6 +67,14 @@ func (op *FDOperator) Free() { op.poll.Free(op) } +func (op *FDOperator) getMode() int32 { + return atomic.LoadInt32(&op.mode) +} + +func (op *FDOperator) setMode(mode int32) { + atomic.StoreInt32(&op.mode, mode) +} + func (op *FDOperator) do() (can bool) { return atomic.CompareAndSwapInt32(&op.state, 1, 2) } @@ -98,5 +111,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil - op.detached = 0 + op.mode = 0 } diff --git a/net_dialer.go b/net_dialer.go index 4c4e8dd2..edcafca1 100644 --- a/net_dialer.go +++ b/net_dialer.go @@ -29,13 +29,22 @@ func DialConnection(network, address string, timeout time.Duration) (connection } // NewDialer only support TCP and unix socket now. -func NewDialer() Dialer { - return &dialer{} +func NewDialer(opts ...Option) Dialer { + d := new(dialer) + if len(opts) > 0 { + d.opts = new(options) + for _, opt := range opts { + opt.f(d.opts) + } + } + return d } var defaultDialer = NewDialer() -type dialer struct{} +type dialer struct { + opts *options +} // DialTimeout implements Dialer. func (d *dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { @@ -59,7 +68,7 @@ func (d *dialer) DialConnection(network, address string, timeout time.Duration) raddr := &UnixAddr{ UnixAddr: net.UnixAddr{Name: address, Net: network}, } - return DialUnix(network, nil, raddr) + return dialUnix(network, nil, raddr, d.opts) default: return nil, net.UnknownNetworkError(network) } @@ -95,9 +104,9 @@ func (d *dialer) dialTCP(ctx context.Context, network, address string) (connecti tcpAddr.Port = portnum tcpAddr.Zone = ipaddr.Zone if ipaddr.IP != nil && ipaddr.IP.To4() == nil { - connection, err = DialTCP(ctx, "tcp6", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp6", nil, tcpAddr, d.opts) } else { - connection, err = DialTCP(ctx, "tcp", nil, tcpAddr) + connection, err = dialTCP(ctx, "tcp", nil, tcpAddr, d.opts) } if err == nil { return connection, nil diff --git a/net_tcpsock.go b/net_tcpsock.go index 2c90634b..87fb84eb 100644 --- a/net_tcpsock.go +++ b/net_tcpsock.go @@ -138,23 +138,16 @@ type TCPConnection struct { } // newTCPConnection wraps *TCPConnection. -func newTCPConnection(conn Conn) (connection *TCPConnection, err error) { +func newTCPConnection(conn Conn, opts *options) (connection *TCPConnection, err error) { connection = &TCPConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialTCP acts like Dial for TCP networks. -// -// The network must be a TCP network name; see func Dial for details. -// -// If laddr is nil, a local address is automatically chosen. -// If the IP field of raddr is nil or an unspecified IP address, the -// local system is assumed. -func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { +func dialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { switch network { case "tcp", "tcp4", "tcp6": default: @@ -167,14 +160,25 @@ func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPCo ctx = context.Background() } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialTCP(ctx, laddr, raddr) + c, err := sd.dialTCP(ctx, laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConnection, error) { +// DialTCP acts like Dial for TCP networks. +// +// The network must be a TCP network name; see func Dial for details. +// +// If laddr is nil, a local address is automatically chosen. +// If the IP field of raddr is nil or an unspecified IP address, the +// local system is assumed. +func DialTCP(ctx context.Context, network string, laddr, raddr *TCPAddr) (*TCPConnection, error) { + return dialTCP(ctx, network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr, opts *options) (*TCPConnection, error) { conn, err := internetSocket(ctx, sd.network, laddr, raddr, syscall.SOCK_STREAM, 0, "dial") // TCP has a rarely used mechanism called a 'simultaneous connection' in @@ -211,7 +215,7 @@ func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPCo if err != nil { return nil, err } - return newTCPConnection(conn) + return newTCPConnection(conn, opts) } func selfConnect(conn *netFD, err error) bool { diff --git a/net_unixsock.go b/net_unixsock.go index c5213a1c..9564dbc7 100644 --- a/net_unixsock.go +++ b/net_unixsock.go @@ -74,41 +74,45 @@ type UnixConnection struct { } // newUnixConnection wraps UnixConnection. -func newUnixConnection(conn Conn) (connection *UnixConnection, err error) { +func newUnixConnection(conn Conn, opts *options) (connection *UnixConnection, err error) { connection = &UnixConnection{} - err = connection.init(conn, nil) + err = connection.init(conn, opts) if err != nil { return nil, err } return connection, nil } -// DialUnix acts like Dial for Unix networks. -// -// The network must be a Unix network name; see func Dial for details. -// -// If laddr is non-nil, it is used as the local address for the -// connection. -func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { +func dialUnix(network string, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { switch network { case "unix", "unixgram", "unixpacket": default: return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: net.UnknownNetworkError(network)} } sd := &sysDialer{network: network, address: raddr.String()} - c, err := sd.dialUnix(context.Background(), laddr, raddr) + c, err := sd.dialUnix(context.Background(), laddr, raddr, opts) if err != nil { return nil, &net.OpError{Op: "dial", Net: network, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err} } return c, nil } -func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr) (*UnixConnection, error) { +// DialUnix acts like Dial for Unix networks. +// +// The network must be a Unix network name; see func Dial for details. +// +// If laddr is non-nil, it is used as the local address for the +// connection. +func DialUnix(network string, laddr, raddr *UnixAddr) (*UnixConnection, error) { + return dialUnix(network, laddr, raddr, nil) +} + +func (sd *sysDialer) dialUnix(ctx context.Context, laddr, raddr *UnixAddr, opts *options) (*UnixConnection, error) { conn, err := unixSocket(ctx, sd.network, laddr, raddr, "dial") if err != nil { return nil, err } - return newUnixConnection(conn) + return newUnixConnection(conn, opts) } func unixSocket(ctx context.Context, network string, laddr, raddr sockaddr, mode string) (conn *netFD, err error) { diff --git a/netpoll_options.go b/netpoll_options.go index 023e574c..8da0a040 100644 --- a/netpoll_options.go +++ b/netpoll_options.go @@ -98,16 +98,23 @@ func WithIdleTimeout(timeout time.Duration) Option { }} } +func WithReadThreshold(readThreshold int64) Option { + return Option{func(op *options) { + op.readThreshold = readThreshold + }} +} + // Option . type Option struct { f func(*options) } type options struct { - onPrepare OnPrepare - onConnect OnConnect - onRequest OnRequest - readTimeout time.Duration - writeTimeout time.Duration - idleTimeout time.Duration + onPrepare OnPrepare + onConnect OnConnect + onRequest OnRequest + readTimeout time.Duration + writeTimeout time.Duration + idleTimeout time.Duration + readThreshold int64 } diff --git a/netpoll_test.go b/netpoll_test.go index 0467e879..047f19a2 100644 --- a/netpoll_test.go +++ b/netpoll_test.go @@ -397,6 +397,114 @@ func TestClientWriteAndClose(t *testing.T) { MustNil(t, err) } +func TestReadThresholdOption(t *testing.T) { + /* + client => server: 102400 bytes + 5 bytes + server cached: 102400 bytes, and throttled + server read: 102400 bytes, and unthrottled + server cached: 5 bytes + server read: 5 bytes + server write: 102400 bytes + 5 bytes + client cached: 102400 bytes, and throttled + client read: 102400 bytes, and unthrottled + client cached: 5 bytes + client read: 5 bytes + */ + readThreshold := 1024 * 100 + trigger := make(chan struct{}) + msg1 := make([]byte, readThreshold) + msg2 := []byte("hello") + var wg sync.WaitGroup + + // server + ln, err := CreateListener("tcp", ":12345") + MustNil(t, err) + wg.Add(3) + svr, _ := NewEventLoop(func(ctx context.Context, connection Connection) error { + if connection.Reader().Len() < readThreshold { + return nil + } + go func() { + defer wg.Done() + // server write + t.Logf("server writing msg1") + _, err := connection.Writer().WriteBinary(msg1) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("server writing msg2") + _, err = connection.Writer().WriteBinary(msg2) + MustNil(t, err) + err = connection.Writer().Flush() + MustNil(t, err) + }() + + // server read + defer wg.Done() + t.Logf("server reading msg1") + trigger <- struct{}{} // let client send msg2 + time.Sleep(time.Millisecond * 100) // ensure client send msg2 + Equal(t, connection.Reader().Len(), readThreshold) + msg, err := connection.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("server reading msg2") + msg, err = connection.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + _, err = connection.Reader().Next(1) + Assert(t, errors.Is(err, ErrEOF)) + t.Logf("server closed") + return nil + }, WithReadThreshold(int64(readThreshold))) + defer svr.Shutdown(context.Background()) + go func() { + svr.Serve(ln) + }() + time.Sleep(time.Millisecond * 100) + + // client write + dialer := NewDialer(WithReadThreshold(int64(readThreshold))) + cli, err := dialer.DialConnection("tcp", "127.0.0.1:12345", time.Second) + MustNil(t, err) + go func() { + defer wg.Done() + t.Logf("client writing msg1") + _, err := cli.Writer().WriteBinary(msg1) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + <-trigger + time.Sleep(time.Millisecond * 100) + t.Logf("client writing msg2") + _, err = cli.Writer().WriteBinary(msg2) + MustNil(t, err) + err = cli.Writer().Flush() + MustNil(t, err) + }() + + // client read + trigger <- struct{}{} // let server send msg2 + time.Sleep(time.Millisecond * 100) // ensure server send msg2 + Equal(t, cli.Reader().Len(), readThreshold) + t.Logf("client reading msg1") + msg, err := cli.Reader().Next(readThreshold) + MustNil(t, err) + Equal(t, len(msg), readThreshold) + t.Logf("client reading msg2") + msg, err = cli.Reader().Next(5) + MustNil(t, err) + Equal(t, len(msg), 5) + + err = cli.Close() + MustNil(t, err) + t.Logf("client closed") + wg.Wait() +} + func createTestListener(network, address string) (Listener, error) { for { ln, err := CreateListener(network, address) diff --git a/poll.go b/poll.go index c494ffd6..414f3afa 100644 --- a/poll.go +++ b/poll.go @@ -59,8 +59,17 @@ const ( // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. - PollR2RW PollEvent = 0x5 - + PollR2RW PollEvent = 0x4 // PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW. - PollRW2R PollEvent = 0x6 + PollRW2R PollEvent = 0x5 + + // PollRW2W is used to remove the readable monitor of FDOperator. + PollRW2W PollEvent = 0x6 + // PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W. + PollW2RW PollEvent = 0x7 + + // PollR2Hup is used to remove the readable monitor of FDOperator. + PollR2Hup PollEvent = 0x8 + // PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup. + PollHup2R PollEvent = 0x9 ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 9c8aa8c9..1741dcf2 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -182,11 +182,14 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: operator.inuse() + operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() + operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: + operator.setMode(ophup) if operator.OnWrite != nil { // means WaitWrite finished evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } else { @@ -194,9 +197,28 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { } p.delOperator(operator) case PollR2RW: + operator.setMode(opreadwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: + operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + case PollRW2W: + operator.setMode(opwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + case PollW2RW: + operator.setMode(opreadwrite) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE + case PollR2Hup: + operator.setMode(ophup) + // kqueue should syscall twice to delete read and write + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE + _, err := syscall.Kevent(p.fd, evs, nil, nil) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE + _, err = syscall.Kevent(p.fd, evs, nil, nil) + return err + case PollHup2R: + operator.setMode(opread) + evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_linux.go b/poll_default_linux.go index a0087ee0..fd902fdd 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -244,16 +244,37 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: // server accept a new connection and wait read operator.inuse() + operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_ADD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollWritable: // client create a new connection and wait connect finished operator.inuse() + operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister + if operator.getMode() == opdetach { + // protect only detach once + return nil + } + operator.setMode(opdetach) p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollR2RW: // connection wait read/write + operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2R: // connection wait read + operator.setMode(opread) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollRW2W: + operator.setMode(opwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollW2RW: + operator.setMode(opreadwrite) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollR2Hup: + operator.setMode(ophup) + op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLRDHUP|syscall.EPOLLERR + case PollHup2R: + operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt)