From af2aae963377df81ee158378e90f649c5fbdec35 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Mon, 17 Apr 2023 18:42:24 +0800 Subject: [PATCH 1/2] chore: bump ants to v2.7.3 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 1ef79846f..401f934c8 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,7 @@ module github.com/panjf2000/gnet/v2 require ( - github.com/panjf2000/ants/v2 v2.7.1 + github.com/panjf2000/ants/v2 v2.7.3 github.com/stretchr/testify v1.8.1 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/zap v1.21.0 diff --git a/go.sum b/go.sum index 0c47ca9cf..92fdb0782 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/panjf2000/ants/v2 v2.7.1 h1:qBy5lfSdbxvrR0yUnZfaEDjf0FlCw4ufsbcsxmE7r+M= -github.com/panjf2000/ants/v2 v2.7.1/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= +github.com/panjf2000/ants/v2 v2.7.3 h1:rHQ0hH0DQvuNUqqlWIMJtkMcDuL1uQAfpX2mIhQ5/s0= +github.com/panjf2000/ants/v2 v2.7.3/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= From 93a1d6ba3b1569843c858603b2176b464f5404c9 Mon Sep 17 00:00:00 2001 From: Andy Pan Date: Sat, 6 May 2023 14:49:36 +0800 Subject: [PATCH 2/2] feat: implement gnet on Windows Fixes #339 --- .github/workflows/test.yml | 12 +- .github/workflows/test_poll_opt.yml | 2 +- README.md | 2 +- README_ZH.md | 2 +- acceptor.go => acceptor_unix.go | 0 acceptor_windows.go | 71 ++++ client_test.go | 8 +- 
client.go => client_unix.go | 52 +-- client_windows.go | 218 +++++++++++++ connection.go => connection_unix.go | 10 +- connection_windows.go | 483 ++++++++++++++++++++++++++++ engine_stub.go | 4 +- engine.go => engine_unix.go | 117 ++++--- engine_windows.go | 161 ++++++++++ eventloop.go => eventloop_unix.go | 0 eventloop_windows.go | 223 +++++++++++++ gnet.go | 11 +- gnet_test.go | 190 +---------- go.mod | 1 + internal/socket/sock_posix.go | 34 +- listener.go => listener_unix.go | 0 listener_windows.go | 110 +++++++ os_unix_test.go | 204 ++++++++++++ os_windows_test.go | 36 +++ pkg/errors/errors.go | 2 + reactor_default_bsd.go | 49 +-- reactor_default_linux.go | 49 +-- reactor_optimized_bsd.go | 40 +-- reactor_optimized_linux.go | 40 +-- 29 files changed, 1754 insertions(+), 377 deletions(-) rename acceptor.go => acceptor_unix.go (100%) create mode 100644 acceptor_windows.go rename client.go => client_unix.go (87%) create mode 100644 client_windows.go rename connection.go => connection_unix.go (98%) create mode 100644 connection_windows.go rename engine.go => engine_unix.go (74%) create mode 100644 engine_windows.go rename eventloop.go => eventloop_unix.go (100%) create mode 100644 eventloop_windows.go rename listener.go => listener_unix.go (100%) create mode 100644 listener_windows.go create mode 100644 os_unix_test.go create mode 100644 os_windows_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3a12937c1..7aef413b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -27,6 +27,7 @@ jobs: os: - ubuntu-latest - macos-latest + #- windows-latest name: Run golangci-lint runs-on: ${{ matrix.os }} steps: @@ -42,14 +43,17 @@ jobs: uses: golangci/golangci-lint-action@v3 with: version: v1.51.2 - args: -v -E gofumpt -E gocritic -E misspell -E revive -E godot + args: -v -E gofumpt -E gocritic -E misspell -E revive -E godot --timeout 5m test: needs: lint strategy: fail-fast: false matrix: - go: [1.17, 1.18, 1.19] - os: 
[ubuntu-latest, macos-latest] + go: ['1.17', '1.18', '1.19', '1.20'] + os: + - ubuntu-latest + - macos-latest + - windows-latest name: Go ${{ matrix.go }} @ ${{ matrix.os }} runs-on: ${{ matrix.os }} steps: @@ -61,7 +65,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v3 with: - go-version: '^1.17' + go-version: ${{ matrix.go }} - name: Print Go environment id: go-env diff --git a/.github/workflows/test_poll_opt.yml b/.github/workflows/test_poll_opt.yml index d4e2509e8..6c28417a6 100644 --- a/.github/workflows/test_poll_opt.yml +++ b/.github/workflows/test_poll_opt.yml @@ -48,7 +48,7 @@ jobs: strategy: fail-fast: false matrix: - go: [1.17, 1.18, 1.19] + go: ['1.17', '1.18', '1.19', '1.20'] os: [ubuntu-latest, macos-latest] name: Go ${{ matrix.go }} @ ${{ matrix.os }} runs-on: ${{ matrix.os }} diff --git a/README.md b/README.md index 2daccbb15..8a647f445 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ English | [中文](README_ZH.md) - [x] Supporting two event-driven mechanisms: `epoll` on **Linux** and `kqueue` on **FreeBSD/DragonFly/Darwin** - [x] Flexible ticker event - [x] Implementation of `gnet` Client -- [ ] **Windows** platform support ([gnet v1](https://github.com/panjf2000/gnet/tree/1.x) is available on Windows, v2 not yet) +- [x] **Windows** platform support (Not for production use, only for debugging and testing) - [ ] **TLS** support - [ ] [io_uring](https://kernel.dk/io_uring.pdf) support diff --git a/README_ZH.md b/README_ZH.md index cdbf6a4f6..c24a0ab70 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -38,7 +38,7 @@ - [x] 支持两种事件驱动机制:**Linux** 里的 `epoll` 以及 **FreeBSD/DragonFly/Darwin** 里的 `kqueue` - [x] 灵活的事件定时器 - [x] 实现 `gnet` 客户端 -- [ ] 支持 **Windows** 平台 ([gnet v1](https://github.com/panjf2000/gnet/tree/1.x) 支持 Windows,v2 暂时不支持) +- [x] 支持 **Windows** 平台 (非生产环境使用,只用来调试和测试) - [ ] 支持 **TLS** - [ ] 支持 [io_uring](https://kernel.dk/io_uring.pdf) diff --git a/acceptor.go b/acceptor_unix.go similarity index 100% rename from acceptor.go rename to 
acceptor_unix.go diff --git a/acceptor_windows.go b/acceptor_windows.go new file mode 100644 index 000000000..6df850982 --- /dev/null +++ b/acceptor_windows.go @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2023 Andy Pan. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package gnet + +import ( + "net" + "runtime" +) + +func (eng *engine) listen() (err error) { + if eng.opts.LockOSThread { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + } + + defer func() { eng.shutdown(err) }() + + var buffer [0x10000]byte + for { + if eng.ln.pc != nil { + // Read data from UDP socket. + n, addr, e := eng.ln.pc.ReadFrom(buffer[:]) + if e != nil { + err = e + eng.opts.Logger.Errorf("failed to receive data from UDP fd due to error:%v", err) + return + } + + el := eng.lb.next(addr) + c := newUDPConn(el, eng.ln.addr, addr) + el.ch <- packUDPConn(c, buffer[:n]) + } else { + // Accept TCP socket. 
+ tc, e := eng.ln.ln.Accept() + if e != nil { + err = e + eng.opts.Logger.Errorf("Accept() fails due to error: %v", err) + return + } + el := eng.lb.next(tc.RemoteAddr()) + c := newTCPConn(tc, el) + el.ch <- c + go func(c *conn, tc net.Conn, el *eventloop) { + var buffer [0x10000]byte + for { + n, err := tc.Read(buffer[:]) + if err != nil { + el.ch <- &netErr{c, err} + return + } + el.ch <- packTCPConn(c, buffer[:n]) + } + }(c, tc, el) + } + } +} diff --git a/client_test.go b/client_test.go index aef877617..a63dd6fce 100644 --- a/client_test.go +++ b/client_test.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build linux || freebsd || dragonfly || darwin -// +build linux freebsd dragonfly darwin +//go:build linux || freebsd || dragonfly || darwin || windows +// +build linux freebsd dragonfly darwin windows package gnet @@ -264,6 +264,7 @@ func (s *testClientServer) OnTraffic(c Conn) (action Action) { } func (s *testClientServer) OnTick() (delay time.Duration, action Action) { + delay = time.Second / 5 if atomic.CompareAndSwapInt32(&s.started, 0, 1) { for i := 0; i < s.nclients; i++ { atomic.AddInt32(&s.clientActive, 1) @@ -278,7 +279,6 @@ func (s *testClientServer) OnTick() (delay time.Duration, action Action) { action = Shutdown return } - delay = time.Second / 5 return } @@ -327,7 +327,7 @@ func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr ) if netDial { var netConn net.Conn - netConn, err = net.Dial(network, addr) + netConn, err = NetDial(network, addr) require.NoError(t, err) c, err = cli.Enroll(netConn) } else { diff --git a/client.go b/client_unix.go similarity index 87% rename from client.go rename to client_unix.go index 8bf6ef80c..e6d297d99 100644 --- a/client.go +++ b/client_unix.go @@ -22,9 +22,9 @@ import ( "errors" "net" "strconv" - "sync" "syscall" + "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" 
"github.com/panjf2000/gnet/v2/internal/math" @@ -43,7 +43,7 @@ type Client struct { } // NewClient creates an instance of Client. -func NewClient(eventHandler EventHandler, opts ...Option) (cli *Client, err error) { +func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { options := loadOptions(opts...) cli = new(Client) cli.opts = options @@ -62,18 +62,26 @@ func NewClient(eventHandler EventHandler, opts ...Option) (cli *Client, err erro if p, err = netpoll.OpenPoller(); err != nil { return } - eng := new(engine) - eng.opts = options - eng.eventHandler = eventHandler - eng.ln = &listener{network: "udp"} - eng.cond = sync.NewCond(&sync.Mutex{}) + + shutdownCtx, shutdown := context.WithCancel(context.Background()) + eng := engine{ + ln: &listener{network: "udp"}, + opts: options, + eventHandler: eh, + workerPool: struct { + *errgroup.Group + shutdownCtx context.Context + shutdown context.CancelFunc + }{&errgroup.Group{}, shutdownCtx, shutdown}, + } if options.Ticker { - eng.tickerCtx, eng.cancelTicker = context.WithCancel(context.Background()) + eng.ticker.ctx, eng.ticker.cancel = context.WithCancel(context.Background()) + } + el := eventloop{ + ln: eng.ln, + engine: &eng, + poller: p, } - el := new(eventloop) - el.ln = eng.ln - el.engine = eng - el.poller = p rbc := options.ReadBufferCap switch { @@ -97,22 +105,18 @@ func NewClient(eventHandler EventHandler, opts ...Option) (cli *Client, err erro el.buffer = make([]byte, options.ReadBufferCap) el.udpSockets = make(map[int]*conn) el.connections = make(map[int]*conn) - el.eventHandler = eventHandler - cli.el = el + el.eventHandler = eh + cli.el = &el return } // Start starts the client event-loop, handing IO events. func (cli *Client) Start() error { cli.el.eventHandler.OnBoot(Engine{}) - cli.el.engine.wg.Add(1) - go func() { - cli.el.run(cli.opts.LockOSThread) - cli.el.engine.wg.Done() - }() + cli.el.engine.workerPool.Go(cli.el.run) // Start the ticker. 
if cli.opts.Ticker { - go cli.el.ticker(cli.el.engine.tickerCtx) + go cli.el.ticker(cli.el.engine.ticker.ctx) } return nil } @@ -120,13 +124,13 @@ func (cli *Client) Start() error { // Stop stops the client event-loop. func (cli *Client) Stop() (err error) { logging.Error(cli.el.poller.UrgentTrigger(func(_ interface{}) error { return gerrors.ErrEngineShutdown }, nil)) - cli.el.engine.wg.Wait() - logging.Error(cli.el.poller.Close()) - cli.el.eventHandler.OnShutdown(Engine{}) // Stop the ticker. if cli.opts.Ticker { - cli.el.engine.cancelTicker() + cli.el.engine.ticker.cancel() } + _ = cli.el.engine.workerPool.Wait() + logging.Error(cli.el.poller.Close()) + cli.el.eventHandler.OnShutdown(Engine{}) if cli.logFlush != nil { err = cli.logFlush() } diff --git a/client_windows.go b/client_windows.go new file mode 100644 index 000000000..f59f5a2d6 --- /dev/null +++ b/client_windows.go @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2023 Andy Pan. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package gnet + +import ( + "context" + "net" + "os" + "path/filepath" + "sync" + + "golang.org/x/sync/errgroup" + + errorx "github.com/panjf2000/gnet/v2/pkg/errors" + "github.com/panjf2000/gnet/v2/pkg/logging" +) + +type Client struct { + opts *Options + el *eventloop + logFlush func() error +} + +func NewClient(eh EventHandler, opts ...Option) (cli *Client, err error) { + options := loadOptions(opts...) 
+ cli = &Client{opts: options} + var logger logging.Logger + if options.LogPath != "" { + if logger, cli.logFlush, err = logging.CreateLoggerAsLocalFile(options.LogPath, options.LogLevel); err != nil { + return + } + } else { + logger = logging.GetDefaultLogger() + } + if options.Logger == nil { + options.Logger = logger + } + + shutdownCtx, shutdown := context.WithCancel(context.Background()) + eng := &engine{ + ln: &listener{}, + opts: options, + workerPool: struct { + *errgroup.Group + shutdownCtx context.Context + shutdown context.CancelFunc + }{&errgroup.Group{}, shutdownCtx, shutdown}, + eventHandler: eh, + } + cli.el = &eventloop{ + ch: make(chan interface{}, 1024), + eng: eng, + connections: make(map[*conn]struct{}), + eventHandler: eh, + } + return +} + +func (cli *Client) Start() error { + cli.el.eventHandler.OnBoot(Engine{}) + cli.el.eng.workerPool.Go(cli.el.run) + if cli.opts.Ticker { + cli.el.eng.ticker.ctx, cli.el.eng.ticker.cancel = context.WithCancel(context.Background()) + cli.el.eng.workerPool.Go(func() error { + cli.el.ticker(cli.el.eng.ticker.ctx) + return nil + }) + } + return nil +} + +func (cli *Client) Stop() (err error) { + cli.el.ch <- errorx.ErrEngineShutdown + if cli.opts.Ticker { + cli.el.eng.ticker.cancel() + } + _ = cli.el.eng.workerPool.Wait() + cli.el.eventHandler.OnShutdown(Engine{}) + if cli.logFlush != nil { + err = cli.logFlush() + } + logging.Cleanup() + return +} + +var ( + mu sync.RWMutex + unixAddrDirs = make(map[string]string) +) + +// unixAddr uses os.MkdirTemp to get a name that is unique. +func unixAddr(addr string) string { + // Pass an empty pattern to get a directory name that is as short as possible. + // If we end up with a name longer than the sun_path field in the sockaddr_un + // struct, we won't be able to make the syscall to open the socket. 
+ d, err := os.MkdirTemp("", "") + if err != nil { + panic(err) + } + + tmpAddr := filepath.Join(d, addr) + mu.Lock() + unixAddrDirs[tmpAddr] = d + mu.Unlock() + + return tmpAddr +} + +func (cli *Client) Dial(network, addr string) (Conn, error) { + var ( + c net.Conn + err error + ) + if network == "unix" { + laddr, _ := net.ResolveUnixAddr(network, unixAddr(addr)) + raddr, _ := net.ResolveUnixAddr(network, addr) + c, err = net.DialUnix(network, laddr, raddr) + if err != nil { + return nil, err + } + } else { + c, err = net.Dial(network, addr) + if err != nil { + return nil, err + } + } + return cli.Enroll(c) +} + +func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) { + switch v := nc.(type) { + case *net.TCPConn: + if cli.opts.TCPNoDelay == TCPNoDelay { + if err = v.SetNoDelay(true); err != nil { + return + } + } + if cli.opts.TCPKeepAlive > 0 { + if err = v.SetKeepAlive(true); err != nil { + return + } + if err = v.SetKeepAlivePeriod(cli.opts.TCPKeepAlive); err != nil { + return + } + } + + c := newTCPConn(nc, cli.el) + cli.el.ch <- c + go func(c *conn, tc net.Conn, el *eventloop) { + var buffer [0x10000]byte + for { + n, err := tc.Read(buffer[:]) + if err != nil { + el.ch <- &netErr{c, err} + return + } + el.ch <- packTCPConn(c, buffer[:n]) + } + }(c, nc, cli.el) + gc = c + case *net.UnixConn: + c := newTCPConn(nc, cli.el) + cli.el.ch <- c + go func(c *conn, uc net.Conn, el *eventloop) { + var buffer [0x10000]byte + for { + n, err := uc.Read(buffer[:]) + if err != nil { + el.ch <- &netErr{c, err} + mu.RLock() + tmpDir := unixAddrDirs[uc.LocalAddr().String()] + mu.RUnlock() + if err := os.RemoveAll(tmpDir); err != nil { + logging.Errorf("failed to remove temporary directory for unix local address: %v", err) + } + return + } + el.ch <- packTCPConn(c, buffer[:n]) + } + }(c, nc, cli.el) + gc = c + case *net.UDPConn: + c := newUDPConn(cli.el, nc.LocalAddr(), nc.RemoteAddr()) + c.rawConn = nc + go func(uc net.Conn, el *eventloop) { + var buffer [0x10000]byte + 
for { + n, err := uc.Read(buffer[:]) + if err != nil { + return + } + c := newUDPConn(cli.el, uc.LocalAddr(), uc.RemoteAddr()) + c.rawConn = uc + el.ch <- packUDPConn(c, buffer[:n]) + } + }(nc, cli.el) + gc = c + default: + return nil, errorx.ErrUnsupportedProtocol + } + + return +} diff --git a/connection.go b/connection_unix.go similarity index 98% rename from connection.go rename to connection_unix.go index ae76bd650..a001aa29b 100644 --- a/connection.go +++ b/connection_unix.go @@ -415,7 +415,15 @@ func (c *conn) Dup() (fd int, err error) { fd, _, err = netpoll.Dup(c.fd); func (c *conn) SetReadBuffer(bytes int) error { return socket.SetRecvBuffer(c.fd, bytes) } func (c *conn) SetWriteBuffer(bytes int) error { return socket.SetSendBuffer(c.fd, bytes) } func (c *conn) SetLinger(sec int) error { return socket.SetLinger(c.fd, sec) } -func (c *conn) SetNoDelay(noDelay bool) error { return socket.SetNoDelay(c.fd, bool2int(noDelay)) } +func (c *conn) SetNoDelay(noDelay bool) error { + return socket.SetNoDelay(c.fd, func(b bool) int { + if b { + return 1 + } + return 0 + }(noDelay)) +} + func (c *conn) SetKeepAlivePeriod(d time.Duration) error { return socket.SetKeepAlivePeriod(c.fd, int(d.Seconds())) } diff --git a/connection_windows.go b/connection_windows.go new file mode 100644 index 000000000..a909472f5 --- /dev/null +++ b/connection_windows.go @@ -0,0 +1,483 @@ +/* + * Copyright (c) 2023 Andy Pan. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package gnet + +import ( + "errors" + "io" + "net" + "syscall" + "time" + + "golang.org/x/sys/windows" + + "github.com/panjf2000/gnet/v2/pkg/buffer/elastic" + errorx "github.com/panjf2000/gnet/v2/pkg/errors" + bbPool "github.com/panjf2000/gnet/v2/pkg/pool/bytebuffer" +) + +type netErr struct { + c *conn + err error +} + +type tcpConn struct { + c *conn + buf *bbPool.ByteBuffer +} + +type udpConn struct { + c *conn +} + +type conn struct { + ctx interface{} // user-defined context + loop *eventloop // owner event-loop + buffer *bbPool.ByteBuffer // reuse memory of inbound data as a temporary buffer + rawConn net.Conn // original connection + localAddr net.Addr // local server addr + remoteAddr net.Addr // remote peer addr + inboundBuffer elastic.RingBuffer // buffer for data from the peer +} + +func packTCPConn(c *conn, buf []byte) *tcpConn { + tc := &tcpConn{c: c, buf: bbPool.Get()} + _, _ = tc.buf.Write(buf) + return tc +} + +func unpackTCPConn(tc *tcpConn) { + tc.c.buffer = tc.buf +} + +func resetTCPConn(tc *tcpConn) { + bbPool.Put(tc.buf) + tc.c.buffer = nil +} + +func packUDPConn(c *conn, buf []byte) *udpConn { + uc := &udpConn{c} + _, _ = uc.c.buffer.Write(buf) + return uc +} + +func newTCPConn(nc net.Conn, el *eventloop) (c *conn) { + c = &conn{ + loop: el, + rawConn: nc, + } + c.localAddr = c.rawConn.LocalAddr() + c.remoteAddr = c.rawConn.RemoteAddr() + return +} + +func (c *conn) releaseTCP() { + c.ctx = nil + c.localAddr = nil + c.remoteAddr = nil + c.rawConn = nil + c.inboundBuffer.Done() + bbPool.Put(c.buffer) + c.buffer = nil +} + +func newUDPConn(el *eventloop, localAddr, remoteAddr net.Addr) *conn { + return &conn{ + loop: el, + buffer: bbPool.Get(), + localAddr: localAddr, + remoteAddr: remoteAddr, + } +} + +func (c *conn) releaseUDP() { + c.ctx = nil + c.localAddr = nil + bbPool.Put(c.buffer) + c.buffer = nil +} + +func (c *conn) resetBuffer() { + c.buffer.Reset() + c.inboundBuffer.Reset() +} + +// ================================== 
Non-concurrency-safe API's ================================== + +func (c *conn) Read(p []byte) (n int, err error) { + if c.inboundBuffer.IsEmpty() { + n = copy(p, c.buffer.B) + c.buffer.B = c.buffer.B[n:] + if n == 0 && len(p) > 0 { + err = io.EOF + } + return + } + n, _ = c.inboundBuffer.Read(p) + if n == len(p) { + return + } + m := copy(p[n:], c.buffer.B) + n += m + c.buffer.B = c.buffer.B[m:] + return +} + +func (c *conn) Next(n int) (buf []byte, err error) { + inBufferLen := c.inboundBuffer.Buffered() + if totalLen := inBufferLen + c.buffer.Len(); n > totalLen { + return nil, io.ErrShortBuffer + } else if n <= 0 { + n = totalLen + } + if c.inboundBuffer.IsEmpty() { + buf = c.buffer.B[:n] + c.buffer.B = c.buffer.B[n:] + return + } + head, tail := c.inboundBuffer.Peek(n) + defer c.inboundBuffer.Discard(n) //nolint:errcheck + if len(head) >= n { + return head[:n], err + } + c.loop.cache.Reset() + c.loop.cache.Write(head) + c.loop.cache.Write(tail) + if inBufferLen >= n { + return c.loop.cache.Bytes(), err + } + + remaining := n - inBufferLen + c.loop.cache.Write(c.buffer.B[:remaining]) + c.buffer.B = c.buffer.B[remaining:] + return c.loop.cache.Bytes(), err +} + +func (c *conn) Peek(n int) (buf []byte, err error) { + inBufferLen := c.inboundBuffer.Buffered() + if totalLen := inBufferLen + c.buffer.Len(); n > totalLen { + return nil, io.ErrShortBuffer + } else if n <= 0 { + n = totalLen + } + if c.inboundBuffer.IsEmpty() { + return c.buffer.B[:n], err + } + head, tail := c.inboundBuffer.Peek(n) + if len(head) >= n { + return head[:n], err + } + c.loop.cache.Reset() + c.loop.cache.Write(head) + c.loop.cache.Write(tail) + if inBufferLen >= n { + return c.loop.cache.Bytes(), err + } + + remaining := n - inBufferLen + c.loop.cache.Write(c.buffer.B[:remaining]) + return c.loop.cache.Bytes(), err +} + +func (c *conn) Discard(n int) (int, error) { + inBufferLen := c.inboundBuffer.Buffered() + tempBufferLen := c.buffer.Len() + if inBufferLen+tempBufferLen < n || n <= 0 { 
+ c.resetBuffer() + return inBufferLen + tempBufferLen, nil + } + if c.inboundBuffer.IsEmpty() { + c.buffer.B = c.buffer.B[n:] + return n, nil + } + + discarded, _ := c.inboundBuffer.Discard(n) + if discarded < inBufferLen { + return discarded, nil + } + + remaining := n - inBufferLen + c.buffer.B = c.buffer.B[remaining:] + return n, nil +} + +func (c *conn) Write(p []byte) (int, error) { + if c.rawConn == nil && c.loop.eng.ln.pc == nil { + return 0, errorx.ErrInvalidConn + } + if c.rawConn != nil { + return c.rawConn.Write(p) + } + return c.loop.eng.ln.pc.WriteTo(p, c.remoteAddr) +} + +func (c *conn) Writev(bs [][]byte) (int, error) { + if c.rawConn != nil { + bb := bbPool.Get() + defer bbPool.Put(bb) + for i := range bs { + _, _ = bb.Write(bs[i]) + } + return c.rawConn.Write(bb.Bytes()) + } + return 0, errorx.ErrInvalidConn +} + +func (c *conn) ReadFrom(r io.Reader) (int64, error) { + if c.rawConn != nil { + return io.Copy(c.rawConn, r) + } + return 0, errorx.ErrInvalidConn +} + +func (c *conn) WriteTo(w io.Writer) (n int64, err error) { + if !c.inboundBuffer.IsEmpty() { + if n, err = c.inboundBuffer.WriteTo(w); err != nil { + return + } + } + defer c.buffer.Reset() + return c.buffer.WriteTo(w) +} + +func (c *conn) Flush() error { + return nil +} + +func (c *conn) InboundBuffered() int { + return c.inboundBuffer.Buffered() + c.buffer.Len() +} + +func (c *conn) OutboundBuffered() int { + return 0 +} + +func (*conn) SetDeadline(_ time.Time) error { + return errorx.ErrUnsupportedOp +} + +func (*conn) SetReadDeadline(_ time.Time) error { + return errorx.ErrUnsupportedOp +} + +func (*conn) SetWriteDeadline(_ time.Time) error { + return errorx.ErrUnsupportedOp +} +func (c *conn) Context() interface{} { return c.ctx } +func (c *conn) SetContext(ctx interface{}) { c.ctx = ctx } +func (c *conn) LocalAddr() net.Addr { return c.localAddr } +func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *conn) Fd() (fd int) { + if c.rawConn == nil { + return -1 + } 
+ + rc, err := c.rawConn.(syscall.Conn).SyscallConn() + if err != nil { + return -1 + } + if err := rc.Control(func(i uintptr) { + fd = int(i) + }); err != nil { + return -1 + } + return +} + +func (c *conn) Dup() (fd int, err error) { + if c.rawConn == nil && c.loop.eng.ln.pc == nil { + return -1, errorx.ErrInvalidConn + } + + var ( + sc syscall.Conn + ok bool + ) + if c.rawConn != nil { + sc, ok = c.rawConn.(syscall.Conn) + } else { + sc, ok = c.loop.eng.ln.pc.(syscall.Conn) + } + + if !ok { + return -1, errors.New("failed to convert net.Conn to syscall.Conn") + } + rc, err := sc.SyscallConn() + if err != nil { + return -1, errors.New("failed to get syscall.RawConn from net.Conn") + } + + var dupHandle windows.Handle + e := rc.Control(func(fd uintptr) { + process := windows.CurrentProcess() + err = windows.DuplicateHandle( + process, + windows.Handle(fd), + process, + &dupHandle, + 0, + true, + windows.DUPLICATE_SAME_ACCESS, + ) + }) + if err != nil { + return -1, err + } + if e != nil { + return -1, e + } + + return int(dupHandle), nil +} + +func (c *conn) SetReadBuffer(bytes int) error { + if c.rawConn == nil && c.loop.eng.ln.pc == nil { + return errorx.ErrInvalidConn + } + + if c.rawConn != nil { + return c.rawConn.(interface{ SetReadBuffer(int) error }).SetReadBuffer(bytes) + } + return c.loop.eng.ln.pc.(interface{ SetReadBuffer(int) error }).SetReadBuffer(bytes) +} + +func (c *conn) SetWriteBuffer(bytes int) error { + if c.rawConn == nil && c.loop.eng.ln.pc == nil { + return errorx.ErrInvalidConn + } + if c.rawConn != nil { + return c.rawConn.(interface{ SetWriteBuffer(int) error }).SetWriteBuffer(bytes) + } + return c.loop.eng.ln.pc.(interface{ SetWriteBuffer(int) error }).SetWriteBuffer(bytes) +} + +func (c *conn) SetLinger(sec int) error { + if c.rawConn == nil { + return errorx.ErrInvalidConn + } + + tc, ok := c.rawConn.(*net.TCPConn) + if !ok { + return errorx.ErrUnsupportedOp + } + return tc.SetLinger(sec) +} + +func (c *conn) SetNoDelay(noDelay bool) 
error { + if c.rawConn == nil { + return errorx.ErrInvalidConn + } + + tc, ok := c.rawConn.(*net.TCPConn) + if !ok { + return errorx.ErrUnsupportedOp + } + return tc.SetNoDelay(noDelay) +} + +func (c *conn) SetKeepAlivePeriod(d time.Duration) error { + if c.rawConn == nil { + return errorx.ErrInvalidConn + } + + tc, ok := c.rawConn.(*net.TCPConn) + if !ok || d < 0 { + return errorx.ErrUnsupportedOp + } + if err := tc.SetKeepAlive(true); err != nil { + return err + } + if err := tc.SetKeepAlivePeriod(d); err != nil { + _ = tc.SetKeepAlive(false) + return err + } + + return nil +} + +// ==================================== Concurrency-safe API's ==================================== + +func (c *conn) AsyncWrite(buf []byte, cb AsyncCallback) error { + if cb == nil { + cb = func(c Conn, err error) error { return nil } + } + _, err := c.Write(buf) + c.loop.ch <- func() error { + return cb(c, err) + } + return nil +} + +func (c *conn) AsyncWritev(bs [][]byte, cb AsyncCallback) error { + buf := bbPool.Get() + for _, b := range bs { + _, _ = buf.Write(b) + } + return c.AsyncWrite(buf.Bytes(), func(c Conn, err error) error { + defer bbPool.Put(buf) + if cb == nil { + return err + } + return cb(c, err) + }) +} + +func (c *conn) Wake(cb AsyncCallback) error { + if cb == nil { + cb = func(c Conn, err error) error { return nil } + } + c.loop.ch <- func() (err error) { + defer func() { + defer func() { + if err == nil { + err = cb(c, nil) + return + } + _ = cb(c, err) + }() + }() + return c.loop.wake(c) + } + return nil +} + +func (c *conn) Close() error { + c.loop.ch <- func() error { + err := c.loop.close(c, nil) + return err + } + return nil +} + +func (c *conn) CloseWithCallback(cb AsyncCallback) error { + if cb == nil { + cb = func(c Conn, err error) error { return nil } + } + c.loop.ch <- func() (err error) { + defer func() { + if err == nil { + err = cb(c, nil) + return + } + _ = cb(c, err) + }() + return c.loop.close(c, nil) + } + return nil +} diff --git a/engine_stub.go 
b/engine_stub.go index 62cc62b72..413539e35 100644 --- a/engine_stub.go +++ b/engine_stub.go @@ -13,8 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build !linux && !freebsd && !dragonfly && !darwin -// +build !linux,!freebsd,!dragonfly,!darwin +//go:build !linux && !freebsd && !dragonfly && !darwin && !windows +// +build !linux,!freebsd,!dragonfly,!darwin,!windows package gnet diff --git a/engine.go b/engine_unix.go similarity index 74% rename from engine.go rename to engine_unix.go index b9161f125..f6f56c7b5 100644 --- a/engine.go +++ b/engine_unix.go @@ -21,54 +21,49 @@ package gnet import ( "context" "runtime" - "sync" "sync/atomic" + "golang.org/x/sync/errgroup" + "github.com/panjf2000/gnet/v2/internal/netpoll" "github.com/panjf2000/gnet/v2/pkg/errors" ) type engine struct { - ln *listener // the listener for accepting new connections - lb loadBalancer // event-loops for handling events - wg sync.WaitGroup // event-loop close WaitGroup - opts *Options // options with engine - once sync.Once // make sure only signalShutdown once - cond *sync.Cond // shutdown signaler - mainLoop *eventloop // main event-loop for accepting connections - inShutdown int32 // whether the engine is in shutdown - tickerCtx context.Context // context for ticker - cancelTicker context.CancelFunc // function to stop the ticker - eventHandler EventHandler // user eventHandler + ln *listener // the listener for accepting new connections + lb loadBalancer // event-loops for handling events + opts *Options // options with engine + mainLoop *eventloop // main event-loop for accepting connections + inShutdown int32 // whether the engine is in shutdown + ticker struct { + ctx context.Context // context for ticker + cancel context.CancelFunc // function to stop the ticker + } + workerPool struct { + *errgroup.Group + + shutdownCtx context.Context + shutdown context.CancelFunc + } + eventHandler EventHandler // user 
eventHandler } func (eng *engine) isInShutdown() bool { return atomic.LoadInt32(&eng.inShutdown) == 1 } -// waitForShutdown waits for a signal to shut down. -func (eng *engine) waitForShutdown() { - eng.cond.L.Lock() - eng.cond.Wait() - eng.cond.L.Unlock() -} +// shutdown signals the engine to shut down. +func (eng *engine) shutdown(err error) { + if err != nil && err != errors.ErrEngineShutdown { + eng.opts.Logger.Errorf("engine is being shutdown with error: %v", err) + } -// signalShutdown signals the engine to shut down. -func (eng *engine) signalShutdown() { - eng.once.Do(func() { - eng.cond.L.Lock() - eng.cond.Signal() - eng.cond.L.Unlock() - }) + eng.workerPool.shutdown() } func (eng *engine) startEventLoops() { eng.lb.iterate(func(i int, el *eventloop) bool { - eng.wg.Add(1) - go func() { - el.run(eng.opts.LockOSThread) - eng.wg.Done() - }() + eng.workerPool.Go(el.run) return true }) } @@ -82,11 +77,7 @@ func (eng *engine) closeEventLoops() { func (eng *engine) startSubReactors() { eng.lb.iterate(func(i int, el *eventloop) bool { - eng.wg.Add(1) - go func() { - el.activateSubReactor(eng.opts.LockOSThread) - eng.wg.Done() - }() + eng.workerPool.Go(el.activateSubReactor) return true }) } @@ -129,7 +120,10 @@ func (eng *engine) activateEventLoops(numEventLoop int) (err error) { // Start event-loops in background. eng.startEventLoops() - go striker.ticker(eng.tickerCtx) + eng.workerPool.Go(func() error { + striker.ticker(eng.ticker.ctx) + return nil + }) return } @@ -166,18 +160,17 @@ func (eng *engine) activateReactors(numEventLoop int) error { eng.mainLoop = el // Start main reactor in background. - eng.wg.Add(1) - go func() { - el.activateMainReactor(eng.opts.LockOSThread) - eng.wg.Done() - }() + eng.workerPool.Go(el.activateMainReactor) } else { return err } // Start the ticker. 
if eng.opts.Ticker { - go eng.mainLoop.ticker(eng.tickerCtx) + eng.workerPool.Go(func() error { + eng.mainLoop.ticker(eng.ticker.ctx) + return nil + }) } return nil @@ -193,7 +186,7 @@ func (eng *engine) start(numEventLoop int) error { func (eng *engine) stop(s Engine) { // Wait on a signal for shutdown - eng.waitForShutdown() + <-eng.workerPool.shutdownCtx.Done() eng.eventHandler.OnShutdown(s) @@ -214,8 +207,14 @@ func (eng *engine) stop(s Engine) { } } - // Wait on all loops to complete reading events - eng.wg.Wait() + // Stop the ticker. + if eng.ticker.cancel != nil { + eng.ticker.cancel() + } + + if err := eng.workerPool.Wait(); err != nil { + eng.opts.Logger.Errorf("engine shutdown error: %v", err) + } eng.closeEventLoops() @@ -226,11 +225,6 @@ func (eng *engine) stop(s Engine) { } } - // Stop the ticker. - if eng.opts.Ticker { - eng.cancelTicker() - } - atomic.StoreInt32(&eng.inShutdown, 1) } @@ -244,11 +238,17 @@ func run(eventHandler EventHandler, listener *listener, options *Options, protoA numEventLoop = options.NumEventLoop } - eng := new(engine) - eng.opts = options - eng.eventHandler = eventHandler - eng.ln = listener - + shutdownCtx, shutdown := context.WithCancel(context.Background()) + eng := engine{ + ln: listener, + opts: options, + workerPool: struct { + *errgroup.Group + shutdownCtx context.Context + shutdown context.CancelFunc + }{&errgroup.Group{}, shutdownCtx, shutdown}, + eventHandler: eventHandler, + } switch options.LB { case RoundRobin: eng.lb = new(roundRobinLoadBalancer) @@ -258,12 +258,11 @@ func run(eventHandler EventHandler, listener *listener, options *Options, protoA eng.lb = new(sourceAddrHashLoadBalancer) } - eng.cond = sync.NewCond(&sync.Mutex{}) if eng.opts.Ticker { - eng.tickerCtx, eng.cancelTicker = context.WithCancel(context.Background()) + eng.ticker.ctx, eng.ticker.cancel = context.WithCancel(context.Background()) } - e := Engine{eng} + e := Engine{&eng} switch eng.eventHandler.OnBoot(e) { case None: case Shutdown: @@ 
-277,7 +276,7 @@ func run(eventHandler EventHandler, listener *listener, options *Options, protoA } defer eng.stop(e) - allEngines.Store(protoAddr, eng) + allEngines.Store(protoAddr, &eng) return nil } diff --git a/engine_windows.go b/engine_windows.go new file mode 100644 index 000000000..afdd737fe --- /dev/null +++ b/engine_windows.go @@ -0,0 +1,161 @@ +/* + * Copyright (c) 2023 Andy Pan. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package gnet + +import ( + "context" + "runtime" + "sync/atomic" + + "golang.org/x/sync/errgroup" + + errorx "github.com/panjf2000/gnet/v2/pkg/errors" +) + +type engine struct { + ln *listener + lb loadBalancer // event-loops for handling events + opts *Options // options with engine + ticker struct { + ctx context.Context + cancel context.CancelFunc + } + inShutdown int32 // whether the engine is in shutdown + workerPool struct { + *errgroup.Group + + shutdownCtx context.Context + shutdown context.CancelFunc + } + eventHandler EventHandler // user eventHandler +} + +func (eng *engine) isInShutdown() bool { + return atomic.LoadInt32(&eng.inShutdown) == 1 +} + +// shutdown signals the engine to shut down. 
+func (eng *engine) shutdown(err error) { + if err != nil && err != errorx.ErrEngineShutdown { + eng.opts.Logger.Errorf("engine is being shutdown with error: %v", err) + } + eng.workerPool.shutdown() +} + +func (eng *engine) start(numEventLoop int) error { + for i := 0; i < numEventLoop; i++ { + el := eventloop{ + ch: make(chan interface{}, 1024), + idx: i, + eng: eng, + connections: make(map[*conn]struct{}), + eventHandler: eng.eventHandler, + } + eng.lb.register(&el) + eng.workerPool.Go(el.run) + if i == 0 && eng.opts.Ticker { + eng.workerPool.Go(func() error { + el.ticker(eng.ticker.ctx) + return nil + }) + } + } + + eng.workerPool.Go(eng.listen) + + return nil +} + +func (eng *engine) stop(engine Engine) error { + <-eng.workerPool.shutdownCtx.Done() + + eng.opts.Logger.Infof("engine is being shutdown...") + eng.eventHandler.OnShutdown(engine) + + eng.ln.close() + + eng.lb.iterate(func(i int, el *eventloop) bool { + el.ch <- errorx.ErrEngineShutdown + return true + }) + + if eng.ticker.cancel != nil { + eng.ticker.cancel() + } + + if err := eng.workerPool.Wait(); err != nil { + eng.opts.Logger.Errorf("engine shutdown error: %v", err) + } + + atomic.StoreInt32(&eng.inShutdown, 1) + + return nil +} + +func run(eventHandler EventHandler, listener *listener, options *Options, protoAddr string) error { + // Figure out the proper number of event-loops/goroutines to run. 
+ numEventLoop := 1 + if options.Multicore { + numEventLoop = runtime.NumCPU() + } + if options.NumEventLoop > 0 { + numEventLoop = options.NumEventLoop + } + + shutdownCtx, shutdown := context.WithCancel(context.Background()) + eng := engine{ + opts: options, + eventHandler: eventHandler, + ln: listener, + workerPool: struct { + *errgroup.Group + shutdownCtx context.Context + shutdown context.CancelFunc + }{&errgroup.Group{}, shutdownCtx, shutdown}, + } + + switch options.LB { + case RoundRobin: + eng.lb = new(roundRobinLoadBalancer) + case LeastConnections: + eng.lb = new(leastConnectionsLoadBalancer) + case SourceAddrHash: + eng.lb = new(sourceAddrHashLoadBalancer) + } + + if options.Ticker { + eng.ticker.ctx, eng.ticker.cancel = context.WithCancel(context.Background()) + } + + engine := Engine{eng: &eng} + switch eventHandler.OnBoot(engine) { + case None: + case Shutdown: + return nil + } + + if err := eng.start(numEventLoop); err != nil { + eng.opts.Logger.Errorf("gnet engine is stopping with error: %v", err) + return err + } + defer eng.stop(engine) //nolint:errcheck + + allEngines.Store(protoAddr, &eng) + + return nil +} diff --git a/eventloop.go b/eventloop_unix.go similarity index 100% rename from eventloop.go rename to eventloop_unix.go diff --git a/eventloop_windows.go b/eventloop_windows.go new file mode 100644 index 000000000..7d8accb96 --- /dev/null +++ b/eventloop_windows.go @@ -0,0 +1,223 @@ +/* + * Copyright (c) 2023 Andy Pan. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package gnet + +import ( + "bytes" + "context" + "runtime" + "strings" + "sync/atomic" + "time" + + "github.com/panjf2000/gnet/v2/pkg/errors" + "github.com/panjf2000/gnet/v2/pkg/logging" +) + +type eventloop struct { + ch chan interface{} // channel for event-loop + idx int // index of event-loop in event-loops + eng *engine // engine in loop + cache bytes.Buffer // temporary buffer for scattered bytes + connCount int32 // number of active connections in event-loop + connections map[*conn]struct{} // TCP connection map: fd -> conn + eventHandler EventHandler // user eventHandler +} + +func (el *eventloop) getLogger() logging.Logger { + return el.eng.opts.Logger +} + +func (el *eventloop) addConn(delta int32) { + atomic.AddInt32(&el.connCount, delta) +} + +func (el *eventloop) loadConn() int32 { + return atomic.LoadInt32(&el.connCount) +} + +func (el *eventloop) run() (err error) { + defer func() { + el.eng.shutdown(err) + for c := range el.connections { + _ = el.close(c, nil) + } + }() + + if el.eng.opts.LockOSThread { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + } + + for i := range el.ch { + switch v := i.(type) { + case error: + err = v + case *netErr: + err = el.close(v.c, v.err) + case *conn: + err = el.open(v) + case *tcpConn: + unpackTCPConn(v) + err = el.read(v.c) + resetTCPConn(v) + case *udpConn: + err = el.readUDP(v.c) + case func() error: + err = v() + } + + if err == errors.ErrEngineShutdown { + el.getLogger().Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) + break + } else if err != nil { + el.getLogger().Debugf("event-loop(%d) got a nonlethal error: %v", el.idx, err) + } + } + + return nil +} + +func (el *eventloop) open(c *conn) error { + el.connections[c] = struct{}{} + el.addConn(1) + + out, action := el.eventHandler.OnOpen(c) + if out != nil { + if _, err := 
c.rawConn.Write(out); err != nil { + return err + } + } + + return el.handleAction(c, action) +} + +func (el *eventloop) read(c *conn) error { + if _, ok := el.connections[c]; !ok { + return nil // ignore stale wakes. + } + action := el.eventHandler.OnTraffic(c) + switch action { + case None: + case Close: + return el.close(c, nil) + case Shutdown: + return errors.ErrEngineShutdown + } + _, _ = c.inboundBuffer.Write(c.buffer.B) + c.buffer.Reset() + + return nil +} + +func (el *eventloop) readUDP(c *conn) error { + action := el.eventHandler.OnTraffic(c) + if action == Shutdown { + return errors.ErrEngineShutdown + } + c.releaseUDP() + return nil +} + +func (el *eventloop) ticker(ctx context.Context) { + if el == nil { + return + } + var ( + action Action + delay time.Duration + timer *time.Timer + ) + defer func() { + if timer != nil { + timer.Stop() + } + }() + var shutdown bool + for { + delay, action = el.eventHandler.OnTick() + switch action { + case None: + case Shutdown: + if !shutdown { + shutdown = true + el.ch <- errors.ErrEngineShutdown + el.getLogger().Debugf("stopping ticker in event-loop(%d) from Tick()", el.idx) + } + } + if timer == nil { + timer = time.NewTimer(delay) + } else { + timer.Reset(delay) + } + select { + case <-ctx.Done(): + el.getLogger().Debugf("stopping ticker in event-loop(%d) from Server, error:%v", el.idx, ctx.Err()) + return + case <-timer.C: + } + } +} + +func (el *eventloop) wake(c *conn) error { + if _, ok := el.connections[c]; !ok { + return nil // ignore stale wakes. 
+ } + action := el.eventHandler.OnTraffic(c) + return el.handleAction(c, action) +} + +func (el *eventloop) close(c *conn, err error) error { + if addr := c.localAddr; addr != nil && strings.HasPrefix(addr.Network(), "udp") { + action := el.eventHandler.OnClose(c, err) + if c.rawConn != nil { + if err := c.rawConn.Close(); err != nil { + el.getLogger().Errorf("failed to close connection(%s), error:%v", c.remoteAddr.String(), err) + } + } + c.releaseUDP() + return el.handleAction(c, action) + } + + if _, ok := el.connections[c]; !ok { + return nil // ignore stale wakes. + } + + action := el.eventHandler.OnClose(c, err) + if err := c.rawConn.Close(); err != nil { + el.getLogger().Errorf("failed to close connection(%s), error:%v", c.remoteAddr.String(), err) + } + delete(el.connections, c) + el.addConn(-1) + c.releaseTCP() + + return el.handleAction(c, action) +} + +func (el *eventloop) handleAction(c *conn, action Action) error { + switch action { + case None: + return nil + case Close: + return el.close(c, nil) + case Shutdown: + return errors.ErrEngineShutdown + default: + return nil + } +} diff --git a/gnet.go b/gnet.go index e0d9e18e3..ed8c9415d 100644 --- a/gnet.go +++ b/gnet.go @@ -88,7 +88,7 @@ func (s Engine) Stop(ctx context.Context) error { return errors.ErrEngineInShutdown } - s.eng.signalShutdown() + s.eng.shutdown(nil) ticker := time.NewTicker(shutdownPollInterval) defer ticker.Stop() @@ -440,7 +440,7 @@ func Stop(ctx context.Context, protoAddr string) error { var eng *engine if s, ok := allEngines.Load(protoAddr); ok { eng = s.(*engine) - eng.signalShutdown() + eng.shutdown(nil) defer allEngines.Delete(protoAddr) } else { return errors.ErrEngineInShutdown @@ -474,10 +474,3 @@ func parseProtoAddr(addr string) (network, address string) { } return } - -func bool2int(b bool) int { - if b { - return 1 - } - return 0 -} diff --git a/gnet_test.go b/gnet_test.go index f2874c0b6..611eecb10 100644 --- a/gnet_test.go +++ b/gnet_test.go @@ -21,12 +21,10 @@ import ( 
"context" "encoding/binary" "errors" - "fmt" "io" "math/rand" "net" "runtime" - "sync" "sync/atomic" "testing" "time" @@ -34,7 +32,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" - "golang.org/x/sys/unix" gerr "github.com/panjf2000/gnet/v2/pkg/errors" "github.com/panjf2000/gnet/v2/pkg/logging" @@ -305,12 +302,18 @@ func (s *testServer) OnTraffic(c Conn) (action Action) { bs[0] = buf.B[:mid] bs[1] = buf.B[mid:] _ = c.AsyncWritev(bs, func(c Conn, err error) error { - logging.Debugf("conn=%s done writev: %v", c.RemoteAddr().String(), err) + if c.RemoteAddr() != nil { + logging.Debugf("conn=%s done writev: %v", c.RemoteAddr().String(), err) + } + bbPool.Put(buf) return nil }) } else { _ = c.AsyncWrite(buf.Bytes(), func(c Conn, err error) error { - logging.Debugf("conn=%s done write: %v", c.RemoteAddr().String(), err) + if c.RemoteAddr() != nil { + logging.Debugf("conn=%s done write: %v", c.RemoteAddr().String(), err) + } + bbPool.Put(buf) return nil }) } @@ -335,7 +338,7 @@ func (s *testServer) OnTraffic(c Conn) (action Action) { fd, err := c.Dup() assert.NoError(s.tester, err) assert.Greater(s.tester, fd, 0) - assert.NoErrorf(s.tester, unix.Close(fd), "close error") + assert.NoErrorf(s.tester, SysClose(fd), "close error") assert.NoErrorf(s.tester, c.SetReadBuffer(streamLen), "set read buffer error") assert.NoErrorf(s.tester, c.SetWriteBuffer(streamLen), "set write buffer error") if s.network == "tcp" { @@ -349,6 +352,7 @@ func (s *testServer) OnTraffic(c Conn) (action Action) { } func (s *testServer) OnTick() (delay time.Duration, action Action) { + delay = time.Second / 5 if atomic.CompareAndSwapInt32(&s.started, 0, 1) { for i := 0; i < s.nclients; i++ { atomic.AddInt32(&s.clientActive, 1) @@ -362,7 +366,6 @@ func (s *testServer) OnTick() (delay time.Duration, action Action) { action = Shutdown return } - delay = time.Second / 5 return } @@ -432,171 +435,6 @@ func startClient(t *testing.T, network, addr 
string, multicore, async bool) { } } -// NOTE: TestServeMulticast can fail with "write: no buffer space available" on wifi interface. -func TestServeMulticast(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - // 224.0.0.169 is an unassigned address from the Local Network Control Block - // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1 - t.Run("udp-multicast", func(t *testing.T) { - testMulticast(t, "224.0.0.169:9991", false, false, -1, 10) - }) - t.Run("udp-multicast-reuseport", func(t *testing.T) { - testMulticast(t, "224.0.0.169:9991", true, false, -1, 10) - }) - t.Run("udp-multicast-reuseaddr", func(t *testing.T) { - testMulticast(t, "224.0.0.169:9991", false, true, -1, 10) - }) - }) - t.Run("IPv6", func(t *testing.T) { - iface, err := findLoopbackInterface() - require.NoError(t, err) - if iface.Flags&net.FlagMulticast != net.FlagMulticast { - t.Skip("multicast is not supported on loopback interface") - } - // ff02::3 is an unassigned address from Link-Local Scope Multicast Addresses - // https://www.iana.org/assignments/ipv6-multicast-addresses/ipv6-multicast-addresses.xhtml#link-local - t.Run("udp-multicast", func(t *testing.T) { - testMulticast(t, fmt.Sprintf("[ff02::3%%%s]:9991", iface.Name), false, false, iface.Index, 10) - }) - t.Run("udp-multicast-reuseport", func(t *testing.T) { - testMulticast(t, fmt.Sprintf("[ff02::3%%%s]:9991", iface.Name), true, false, iface.Index, 10) - }) - t.Run("udp-multicast-reuseaddr", func(t *testing.T) { - testMulticast(t, fmt.Sprintf("[ff02::3%%%s]:9991", iface.Name), false, true, iface.Index, 10) - }) - }) -} - -func findLoopbackInterface() (*net.Interface, error) { - ifaces, err := net.Interfaces() - if err != nil { - return nil, err - } - for _, iface := range ifaces { - if iface.Flags&net.FlagLoopback == net.FlagLoopback { - return &iface, nil - } - } - return nil, errors.New("no loopback interface") -} - -func testMulticast(t *testing.T, addr string, reuseport, 
reuseaddr bool, index, nclients int) { - ts := &testMcastServer{ - t: t, - addr: addr, - nclients: nclients, - } - options := []Option{ - WithReuseAddr(reuseaddr), - WithReusePort(reuseport), - WithSocketRecvBuffer(2 * nclients * 1024), // enough space to receive messages from nclients to eliminate dropped packets - WithTicker(true), - } - if index != -1 { - options = append(options, WithMulticastInterfaceIndex(index)) - } - err := Run(ts, "udp://"+addr, options...) - assert.NoError(t, err) -} - -type testMcastServer struct { - *BuiltinEventEngine - t *testing.T - mcast sync.Map - addr string - nclients int - started int32 - active int32 -} - -func (s *testMcastServer) startMcastClient() { - rand.Seed(time.Now().UnixNano()) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - c, err := net.Dial("udp", s.addr) - require.NoError(s.t, err) - defer c.Close() - ch := make(chan []byte, 10000) - s.mcast.Store(c.LocalAddr().String(), ch) - duration := time.Duration((rand.Float64()*2+1)*float64(time.Second)) / 2 - s.t.Logf("test duration: %dms", duration/time.Millisecond) - start := time.Now() - for time.Since(start) < duration { - reqData := make([]byte, 1024) - _, err = rand.Read(reqData) - require.NoError(s.t, err) - _, err = c.Write(reqData) - require.NoError(s.t, err) - // Workaround for MacOS "write: no buffer space available" error messages - // https://developer.apple.com/forums/thread/42334 - time.Sleep(time.Millisecond) - select { - case respData := <-ch: - require.Equalf(s.t, reqData, respData, "response mismatch, length of bytes: %d vs %d", len(reqData), len(respData)) - case <-ctx.Done(): - require.Fail(s.t, "timeout receiving message") - return - } - } -} - -func (s *testMcastServer) OnTraffic(c Conn) (action Action) { - buf, _ := c.Next(-1) - b := make([]byte, len(buf)) - copy(b, buf) - ch, ok := s.mcast.Load(c.RemoteAddr().String()) - require.True(s.t, ok) - ch.(chan []byte) <- b - return -} - -func (s 
*testMcastServer) OnTick() (delay time.Duration, action Action) { - if atomic.CompareAndSwapInt32(&s.started, 0, 1) { - for i := 0; i < s.nclients; i++ { - atomic.AddInt32(&s.active, 1) - go func() { - s.startMcastClient() - atomic.AddInt32(&s.active, -1) - }() - } - } - if atomic.LoadInt32(&s.active) == 0 { - action = Shutdown - return - } - delay = time.Second / 5 - return -} - -type testMulticastBindServer struct { - *BuiltinEventEngine -} - -func (t *testMulticastBindServer) OnTick() (delay time.Duration, action Action) { - action = Shutdown - return -} - -func TestMulticastBindIPv4(t *testing.T) { - ts := &testMulticastBindServer{} - iface, err := findLoopbackInterface() - require.NoError(t, err) - err = Run(ts, "udp://224.0.0.169:9991", - WithMulticastInterfaceIndex(iface.Index), - WithTicker(true)) - assert.NoError(t, err) -} - -func TestMulticastBindIPv6(t *testing.T) { - ts := &testMulticastBindServer{} - iface, err := findLoopbackInterface() - require.NoError(t, err) - err = Run(ts, fmt.Sprintf("udp://[ff02::3%%%s]:9991", iface.Name), - WithMulticastInterfaceIndex(iface.Index), - WithTicker(true)) - assert.NoError(t, err) -} - func TestDefaultGnetServer(t *testing.T) { svr := BuiltinEventEngine{} svr.OnBoot(Engine{}) @@ -634,12 +472,12 @@ type testTickServer struct { } func (t *testTickServer) OnTick() (delay time.Duration, action Action) { + delay = time.Millisecond * 10 if t.count == 25 { action = Shutdown return } t.count++ - delay = time.Millisecond * 10 return } @@ -771,7 +609,7 @@ func testShutdown(t *testing.T, network, addr string) { events := &testShutdownServer{tester: t, network: network, addr: addr, N: 10} err := Run(events, network+"://"+addr, WithTicker(true), WithReadBufferCap(512), WithWriteBufferCap(512)) assert.NoError(t, err) - require.Equal(t, int(events.clients), 0, "did not call close on all clients") + require.Equal(t, 0, int(events.clients), "did not close all clients") } func TestCloseActionError(t *testing.T) { @@ -1182,14 
+1020,14 @@ func testEngineStop(t *testing.T, network, addr string) { events1 := &testStopEngine{tester: t, network: network, addr: addr, protoAddr: network + "://" + addr, name: "1", stopIter: 2} events2 := &testStopEngine{tester: t, network: network, addr: addr, protoAddr: network + "://" + addr, name: "2", stopIter: 5} - result1 := make(chan error) + result1 := make(chan error, 1) go func() { err := Run(events1, events1.protoAddr, WithTicker(true), WithReuseAddr(true), WithReusePort(true)) result1 <- err }() // ensure the first handler processes before starting the next since the delay per tick is 100ms time.Sleep(150 * time.Millisecond) - result2 := make(chan error) + result2 := make(chan error, 1) go func() { err := Run(events2, events2.protoAddr, WithTicker(true), WithReuseAddr(true), WithReusePort(true)) result2 <- err diff --git a/go.mod b/go.mod index 401f934c8..e729d795d 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ require ( github.com/stretchr/testify v1.8.1 github.com/valyala/bytebufferpool v1.0.0 go.uber.org/zap v1.21.0 + golang.org/x/sync v0.1.0 golang.org/x/sys v0.3.0 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) diff --git a/internal/socket/sock_posix.go b/internal/socket/sock_posix.go index db43ff542..f2a2d810c 100644 --- a/internal/socket/sock_posix.go +++ b/internal/socket/sock_posix.go @@ -1,20 +1,20 @@ -/* - * Copyright 2009 The Go Authors. All rights reserved. - * Copyright (c) 2022 Andy Pan. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - * - */ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright (c) 2022 Andy Pan. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux || freebsd || dragonfly || darwin +// +build linux freebsd dragonfly darwin package socket diff --git a/listener.go b/listener_unix.go similarity index 100% rename from listener.go rename to listener_unix.go diff --git a/listener_windows.go b/listener_windows.go new file mode 100644 index 000000000..2b3c7cf47 --- /dev/null +++ b/listener_windows.go @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2023 Andy Pan. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package gnet + +import ( + "context" + "net" + "os" + "sync" + "syscall" + + "golang.org/x/sys/windows" + + errorx "github.com/panjf2000/gnet/v2/pkg/errors" + "github.com/panjf2000/gnet/v2/pkg/logging" +) + +type listener struct { + network string + address string + once sync.Once + ln net.Listener + pc net.PacketConn + addr net.Addr +} + +func (l *listener) dup() (int, string, error) { + var ( + file *os.File + err error + ) + if l.pc != nil { + file, err = l.pc.(*net.UDPConn).File() + } else { + file, err = l.ln.(interface{ File() (*os.File, error) }).File() + } + if err != nil { + return 0, "dup", err + } + return int(file.Fd()), "", nil +} + +func (l *listener) close() { + l.once.Do(func() { + if l.pc != nil { + logging.Error(os.NewSyscallError("close", l.pc.Close())) + return + } + logging.Error(os.NewSyscallError("close", l.ln.Close())) + if l.network == "unix" { + logging.Error(os.RemoveAll(l.address)) + } + }) +} + +func initListener(network, addr string, options *Options) (l *listener, err error) { + lc := net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + if network != "unix" && (options.ReuseAddr || options.ReusePort) { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1) + } + if options.TCPNoDelay == TCPNoDelay { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_TCP, windows.TCP_NODELAY, 1) + } + if options.SocketRecvBuffer > 0 { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, options.SocketRecvBuffer) + } + if options.SocketSendBuffer > 0 { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, options.SocketSendBuffer) + } + }) + }, + KeepAlive: options.TCPKeepAlive, + } + l = &listener{network: network, address: addr} + switch network { + case "udp", "udp4", "udp6": + if l.pc, err = lc.ListenPacket(context.Background(), network, 
addr); err != nil { + return nil, err + } + l.addr = l.pc.LocalAddr() + case "unix": + logging.Error(os.Remove(addr)) + fallthrough + case "tcp", "tcp4", "tcp6": + if l.ln, err = lc.Listen(context.Background(), network, addr); err != nil { + return nil, err + } + l.addr = l.ln.Addr() + default: + err = errorx.ErrUnsupportedProtocol + } + return +} diff --git a/os_unix_test.go b/os_unix_test.go new file mode 100644 index 000000000..a677494a7 --- /dev/null +++ b/os_unix_test.go @@ -0,0 +1,204 @@ +// Copyright (c) 2023 Andy Pan. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux || freebsd || dragonfly || darwin +// +build linux freebsd dragonfly darwin + +package gnet + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +var ( + SysClose = unix.Close + NetDial = net.Dial +) + +// NOTE: TestServeMulticast can fail with "write: no buffer space available" on wifi interface. 
+func TestServeMulticast(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + // 224.0.0.169 is an unassigned address from the Local Network Control Block + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml#multicast-addresses-1 + t.Run("udp-multicast", func(t *testing.T) { + testMulticast(t, "224.0.0.169:9991", false, false, -1, 10) + }) + t.Run("udp-multicast-reuseport", func(t *testing.T) { + testMulticast(t, "224.0.0.169:9991", true, false, -1, 10) + }) + t.Run("udp-multicast-reuseaddr", func(t *testing.T) { + testMulticast(t, "224.0.0.169:9991", false, true, -1, 10) + }) + }) + t.Run("IPv6", func(t *testing.T) { + iface, err := findLoopbackInterface() + require.NoError(t, err) + if iface.Flags&net.FlagMulticast != net.FlagMulticast { + t.Skip("multicast is not supported on loopback interface") + } + // ff02::3 is an unassigned address from Link-Local Scope Multicast Addresses + // https://www.iana.org/assignments/ipv6-multicast-addresses/ipv6-multicast-addresses.xhtml#link-local + t.Run("udp-multicast", func(t *testing.T) { + testMulticast(t, fmt.Sprintf("[ff02::3%%%s]:9991", iface.Name), false, false, iface.Index, 10) + }) + t.Run("udp-multicast-reuseport", func(t *testing.T) { + testMulticast(t, fmt.Sprintf("[ff02::3%%%s]:9991", iface.Name), true, false, iface.Index, 10) + }) + t.Run("udp-multicast-reuseaddr", func(t *testing.T) { + testMulticast(t, fmt.Sprintf("[ff02::3%%%s]:9991", iface.Name), false, true, iface.Index, 10) + }) + }) +} + +func findLoopbackInterface() (*net.Interface, error) { + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, iface := range ifaces { + if iface.Flags&net.FlagLoopback == net.FlagLoopback { + return &iface, nil + } + } + return nil, errors.New("no loopback interface") +} + +func testMulticast(t *testing.T, addr string, reuseport, reuseaddr bool, index, nclients int) { + ts := &testMcastServer{ + t: t, + addr: addr, + nclients: nclients, + } + options := 
[]Option{ + WithReuseAddr(reuseaddr), + WithReusePort(reuseport), + WithSocketRecvBuffer(2 * nclients * 1024), // enough space to receive messages from nclients to eliminate dropped packets + WithTicker(true), + } + if index != -1 { + options = append(options, WithMulticastInterfaceIndex(index)) + } + err := Run(ts, "udp://"+addr, options...) + assert.NoError(t, err) +} + +type testMcastServer struct { + *BuiltinEventEngine + t *testing.T + mcast sync.Map + addr string + nclients int + started int32 + active int32 +} + +func (s *testMcastServer) startMcastClient() { + rand.Seed(time.Now().UnixNano()) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + c, err := net.Dial("udp", s.addr) + require.NoError(s.t, err) + defer c.Close() + ch := make(chan []byte, 10000) + s.mcast.Store(c.LocalAddr().String(), ch) + duration := time.Duration((rand.Float64()*2+1)*float64(time.Second)) / 2 + s.t.Logf("test duration: %dms", duration/time.Millisecond) + start := time.Now() + for time.Since(start) < duration { + reqData := make([]byte, 1024) + _, err = rand.Read(reqData) + require.NoError(s.t, err) + _, err = c.Write(reqData) + require.NoError(s.t, err) + // Workaround for MacOS "write: no buffer space available" error messages + // https://developer.apple.com/forums/thread/42334 + time.Sleep(time.Millisecond) + select { + case respData := <-ch: + require.Equalf(s.t, reqData, respData, "response mismatch, length of bytes: %d vs %d", len(reqData), len(respData)) + case <-ctx.Done(): + require.Fail(s.t, "timeout receiving message") + return + } + } +} + +func (s *testMcastServer) OnTraffic(c Conn) (action Action) { + buf, _ := c.Next(-1) + b := make([]byte, len(buf)) + copy(b, buf) + ch, ok := s.mcast.Load(c.RemoteAddr().String()) + require.True(s.t, ok) + ch.(chan []byte) <- b + return +} + +func (s *testMcastServer) OnTick() (delay time.Duration, action Action) { + if atomic.CompareAndSwapInt32(&s.started, 0, 1) { + for i := 0; i < 
s.nclients; i++ { + atomic.AddInt32(&s.active, 1) + go func() { + s.startMcastClient() + atomic.AddInt32(&s.active, -1) + }() + } + } + if atomic.LoadInt32(&s.active) == 0 { + action = Shutdown + return + } + delay = time.Second / 5 + return +} + +type testMulticastBindServer struct { + *BuiltinEventEngine +} + +func (t *testMulticastBindServer) OnTick() (delay time.Duration, action Action) { + action = Shutdown + return +} + +func TestMulticastBindIPv4(t *testing.T) { + ts := &testMulticastBindServer{} + iface, err := findLoopbackInterface() + require.NoError(t, err) + err = Run(ts, "udp://224.0.0.169:9991", + WithMulticastInterfaceIndex(iface.Index), + WithTicker(true)) + assert.NoError(t, err) +} + +func TestMulticastBindIPv6(t *testing.T) { + ts := &testMulticastBindServer{} + iface, err := findLoopbackInterface() + require.NoError(t, err) + err = Run(ts, fmt.Sprintf("udp://[ff02::3%%%s]:9991", iface.Name), + WithMulticastInterfaceIndex(iface.Index), + WithTicker(true)) + assert.NoError(t, err) +} diff --git a/os_windows_test.go b/os_windows_test.go new file mode 100644 index 000000000..b5d6ff579 --- /dev/null +++ b/os_windows_test.go @@ -0,0 +1,36 @@ +// Copyright (c) 2023 Andy Pan. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build windows +// +build windows + +package gnet + +import ( + "net" + "syscall" +) + +func SysClose(fd int) error { + return syscall.CloseHandle(syscall.Handle(fd)) +} + +func NetDial(network, addr string) (net.Conn, error) { + if network == "unix" { + laddr, _ := net.ResolveUnixAddr(network, unixAddr(addr)) + raddr, _ := net.ResolveUnixAddr(network, addr) + return net.DialUnix(network, laddr, raddr) + } + return net.Dial(network, addr) +} diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index d10282b86..a5fc92fbe 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -43,4 +43,6 @@ var ( ErrNegativeSize = errors.New("negative size is invalid") // ErrNoIPv4AddressOnInterface occurs when an IPv4 multicast address is set on an interface but IPv4 is not configured. ErrNoIPv4AddressOnInterface = errors.New("no IPv4 address on interface") + // ErrInvalidConn occurs when the connection is invalid. + ErrInvalidConn = errors.New("invalid connection") ) diff --git a/reactor_default_bsd.go b/reactor_default_bsd.go index c87b542f8..81face4b2 100644 --- a/reactor_default_bsd.go +++ b/reactor_default_bsd.go @@ -27,33 +27,31 @@ import ( "github.com/panjf2000/gnet/v2/pkg/errors" ) -func (el *eventloop) activateMainReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateMainReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer el.engine.signalShutdown() - err := el.poller.Polling(func(fd int, filter int16) error { return el.engine.accept(fd, filter) }) if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("main reactor is exiting in terms of the demand from user, %v", err) + err = nil } else if err != nil { el.engine.opts.Logger.Errorf("main reactor is exiting due to error: %v", err) } + + el.engine.shutdown(err) + + return err } -func (el *eventloop) activateSubReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateSubReactor() error 
{ + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.engine.signalShutdown() - }() - err := el.poller.Polling(func(fd int, filter int16) (err error) { if c, ack := el.connections[fd]; ack { switch filter { @@ -71,23 +69,23 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) { }) if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) + err = nil } else if err != nil { el.engine.opts.Logger.Errorf("event-loop(%d) is exiting due to error: %v", el.idx, err) } + + el.closeAllSockets() + el.engine.shutdown(err) + + return err } -func (el *eventloop) run(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) run() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.ln.close() - el.engine.signalShutdown() - }() - err := el.poller.Polling(func(fd int, filter int16) (err error) { if c, ack := el.connections[fd]; ack { switch filter { @@ -104,5 +102,16 @@ func (el *eventloop) run(lockOSThread bool) { } return el.accept(fd, filter) }) - el.getLogger().Debugf("event-loop(%d) is exiting due to error: %v", el.idx, err) + if err == errors.ErrEngineShutdown { + el.engine.opts.Logger.Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) + err = nil + } else if err != nil { + el.engine.opts.Logger.Errorf("event-loop(%d) is exiting due to error: %v", el.idx, err) + } + + el.closeAllSockets() + el.ln.close() + el.engine.shutdown(err) + + return err } diff --git a/reactor_default_linux.go b/reactor_default_linux.go index 765f75196..52fb609d2 100644 --- a/reactor_default_linux.go +++ b/reactor_default_linux.go @@ -24,33 +24,31 @@ import ( "github.com/panjf2000/gnet/v2/pkg/errors" ) -func (el *eventloop) activateMainReactor(lockOSThread bool) { - if lockOSThread { +func (el 
*eventloop) activateMainReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer el.engine.signalShutdown() - err := el.poller.Polling(func(fd int, ev uint32) error { return el.engine.accept(fd, ev) }) if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("main reactor is exiting in terms of the demand from user, %v", err) + err = nil } else if err != nil { el.engine.opts.Logger.Errorf("main reactor is exiting due to error: %v", err) } + + el.engine.shutdown(err) + + return err } -func (el *eventloop) activateSubReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateSubReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.engine.signalShutdown() - }() - err := el.poller.Polling(func(fd int, ev uint32) error { if c, ack := el.connections[fd]; ack { // Don't change the ordering of processing EPOLLOUT | EPOLLRDHUP / EPOLLIN unless you're 100% @@ -78,23 +76,23 @@ func (el *eventloop) activateSubReactor(lockOSThread bool) { if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) + err = nil } else if err != nil { el.engine.opts.Logger.Errorf("event-loop(%d) is exiting due to error: %v", el.idx, err) } + + el.closeAllSockets() + el.engine.shutdown(err) + + return err } -func (el *eventloop) run(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) run() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.ln.close() - el.engine.signalShutdown() - }() - err := el.poller.Polling(func(fd int, ev uint32) error { if c, ok := el.connections[fd]; ok { // Don't change the ordering of processing EPOLLOUT | EPOLLRDHUP / EPOLLIN unless you're 100% @@ -121,5 +119,16 @@ func (el *eventloop) 
run(lockOSThread bool) { return el.accept(fd, ev) }) - el.getLogger().Debugf("event-loop(%d) is exiting due to error: %v", el.idx, err) + if err == errors.ErrEngineShutdown { + el.engine.opts.Logger.Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) + err = nil + } else if err != nil { + el.engine.opts.Logger.Errorf("event-loop(%d) is exiting due to error: %v", el.idx, err) + } + + el.closeAllSockets() + el.ln.close() + el.engine.shutdown(err) + + return err } diff --git a/reactor_optimized_bsd.go b/reactor_optimized_bsd.go index e628cf6c1..e6e25d58b 100644 --- a/reactor_optimized_bsd.go +++ b/reactor_optimized_bsd.go @@ -24,53 +24,55 @@ import ( "github.com/panjf2000/gnet/v2/pkg/errors" ) -func (el *eventloop) activateMainReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateMainReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer el.engine.signalShutdown() - err := el.poller.Polling() if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("main reactor is exiting in terms of the demand from user, %v", err) } else if err != nil { el.engine.opts.Logger.Errorf("main reactor is exiting due to error: %v", err) } + + el.engine.shutdown(err) + + return err } -func (el *eventloop) activateSubReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateSubReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.engine.signalShutdown() - }() - err := el.poller.Polling() if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) } else if err != nil { el.engine.opts.Logger.Errorf("event-loop(%d) is exiting due to error: %v", el.idx, err) } + + el.closeAllSockets() + el.engine.shutdown(err) + + return err } -func (el *eventloop) run(lockOSThread 
bool) { - if lockOSThread { +func (el *eventloop) run() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.ln.close() - el.engine.signalShutdown() - }() - err := el.poller.Polling() el.getLogger().Debugf("event-loop(%d) is exiting due to error: %v", el.idx, err) + + el.closeAllSockets() + el.ln.close() + el.engine.shutdown(err) + + return err } diff --git a/reactor_optimized_linux.go b/reactor_optimized_linux.go index 319925468..80a5e32cc 100644 --- a/reactor_optimized_linux.go +++ b/reactor_optimized_linux.go @@ -23,53 +23,55 @@ import ( "github.com/panjf2000/gnet/v2/pkg/errors" ) -func (el *eventloop) activateMainReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateMainReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer el.engine.signalShutdown() - err := el.poller.Polling() if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("main reactor is exiting in terms of the demand from user, %v", err) } else if err != nil { el.engine.opts.Logger.Errorf("main reactor is exiting due to error: %v", err) } + + el.engine.shutdown(err) + + return err } -func (el *eventloop) activateSubReactor(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) activateSubReactor() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.engine.signalShutdown() - }() - err := el.poller.Polling() if err == errors.ErrEngineShutdown { el.engine.opts.Logger.Debugf("event-loop(%d) is exiting in terms of the demand from user, %v", el.idx, err) } else if err != nil { el.engine.opts.Logger.Errorf("event-loop(%d) is exiting due to error: %v", el.idx, err) } + + el.closeAllSockets() + el.engine.shutdown(err) + + return err } -func (el *eventloop) run(lockOSThread bool) { - if lockOSThread { +func (el *eventloop) 
run() error { + if el.engine.opts.LockOSThread { runtime.LockOSThread() defer runtime.UnlockOSThread() } - defer func() { - el.closeAllSockets() - el.ln.close() - el.engine.signalShutdown() - }() - err := el.poller.Polling() el.getLogger().Debugf("event-loop(%d) is exiting due to error: %v", el.idx, err) + + el.closeAllSockets() + el.ln.close() + el.engine.shutdown(err) + + return err }