From 59aa7cbf64b6a77d30b437afbecf7909938375d5 Mon Sep 17 00:00:00 2001 From: wangzhuowei Date: Wed, 10 Jan 2024 11:23:54 +0800 Subject: [PATCH] fix: race condition --- connection_impl.go | 17 +++----- connection_reactor.go | 58 ++++++++++++++-------------- connection_test.go | 10 +---- fd_operator.go | 19 +++------ poll.go | 22 ++++------- poll_default_bsd.go | 21 +--------- poll_default_bsd_test.go | 83 ++++++++++++++++++++++++++++++++++++++++ poll_default_linux.go | 23 +---------- 8 files changed, 133 insertions(+), 120 deletions(-) create mode 100644 poll_default_bsd_test.go diff --git a/connection_impl.go b/connection_impl.go index 43ea5cb9..57dab9ab 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -503,17 +503,10 @@ func (c *connection) flush() error { if c.outputBuffer.IsEmpty() { return nil } - if c.operator.getMode() == ophup { - // triggered read throttled, so here shouldn't trigger read event again - err = c.operator.Control(PollHup2W) - } else { - err = c.operator.Control(PollR2RW) - } - c.operator.done() - if err != nil { - return Exception(err, "when flush") - } + // no need to check whether resuming write succeeded + // if resuming fails, triggerWrite(err) will be invoked on the connection, and waitFlush will return err + c.resumeWrite() return c.waitFlush() } @@ -546,8 +539,8 @@ func (c *connection) waitFlush() (err error) { default: } // if timeout, remove write event from poller - // we cannot flush it again, since we don't if the poller is still process outputBuffer - c.operator.Control(PollRW2R) + // we cannot flush it again, since we don't know if the poller is still processing outputBuffer + c.pauseWrite() return Exception(ErrWriteTimeout, c.remoteAddr.String()) } } diff --git a/connection_reactor.go b/connection_reactor.go index 8fd582d4..3ac022d7 100644 --- a/connection_reactor.go +++ b/connection_reactor.go @@ -80,6 +80,12 @@ func (c *connection) closeBuffer() { // inputs implements FDOperator. 
func (c *connection) inputs(vs [][]byte) (rs [][]byte) { + // trigger throttle + if c.readBufferThreshold > 0 && int64(c.inputBuffer.Len()) >= c.readBufferThreshold { + c.pauseRead() + return + } + vs[0] = c.inputBuffer.book(c.bookSize, c.maxSize) return vs[:1] } @@ -123,6 +129,7 @@ func (c *connection) inputAck(n int) (err error) { func (c *connection) outputs(vs [][]byte) (rs [][]byte, supportZeroCopy bool) { if c.outputBuffer.IsEmpty() { c.pauseWrite() + c.triggerWrite(nil) return rs, c.supportZeroCopy } rs = c.outputBuffer.GetBytes(vs) @@ -137,50 +144,43 @@ func (c *connection) outputAck(n int) (err error) { } if c.outputBuffer.IsEmpty() { c.pauseWrite() + c.triggerWrite(nil) } return nil } +/* The race description of operator event monitoring +- A Pause operation removes the old event monitor of the operator +- A Resume operation adds a new event monitor of the operator +- Only the poller may use Pause to remove an event monitor, and the poller already holds the op.do() lock +- Only the user may use Resume, and the user's operation may compete with the poller's operation +- If such competition happens, it is safe to proceed despite the race, because every resume operation monitors all events. + * If resume happens first and pause later, the poller will monitor exactly the events it needs. + * If pause happens first and resume later, the poller will monitor duplicate events, which will be removed after the poller is next triggered. + The poller ensures the duplicate events are removed. +- If there is no readBufferThreshold option, the code path will be simpler and more efficient. +*/ + // pauseWrite removed the monitoring of write events. // pauseWrite used in poller func (c *connection) pauseWrite() { - switch c.operator.getMode() { - case opreadwrite: - c.operator.Control(PollRW2R) - case opwrite: - c.operator.Control(PollW2Hup) - } - c.triggerWrite(nil) + c.operator.Control(PollRW2R) +} + +// resumeWrite add the monitoring of write events. 
+// resumeWrite used by users +func (c *connection) resumeWrite() { + c.operator.Control(PollR2RW) } // pauseRead removed the monitoring of read events. // pauseRead used in poller func (c *connection) pauseRead() { - // Note that the poller ensure that every fd should read all left data in socket buffer before detach it. - // So the operator mode should never be ophup. - var changeTo PollEvent - switch c.operator.getMode() { - case opread: - changeTo = PollR2Hup - case opreadwrite: - changeTo = PollRW2W - } - if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 0, 1) { - c.operator.Control(changeTo) - } + c.operator.Control(PollRW2W) } // resumeRead add the monitoring of read events. // resumeRead used by users func (c *connection) resumeRead() { - var changeTo PollEvent - switch c.operator.getMode() { - case ophup: - changeTo = PollHup2R - case opwrite: - changeTo = PollW2RW - } - if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 1, 0) { - c.operator.Control(changeTo) - } + c.operator.Control(PollW2RW) } diff --git a/connection_test.go b/connection_test.go index 528f29d3..6d74d308 100644 --- a/connection_test.go +++ b/connection_test.go @@ -784,18 +784,10 @@ func TestConnectionReadThresholdWithClosed(t *testing.T) { MustNil(t, err) t.Logf("read non-throttled data") - // continue read throttled data + // continue read throttled data with EOF buf, err = connection.Reader().Next(5) - MustNil(t, err) - t.Logf("read throttled data: [%s]", buf) Equal(t, len(buf), 5) MustNil(t, err) - err = connection.Reader().Release() - MustNil(t, err) - Equal(t, connection.Reader().Len(), 0) - - _, err = connection.Reader().Next(1) - Assert(t, errors.Is(err, ErrEOF)) trigger <- struct{}{} return nil } diff --git a/fd_operator.go b/fd_operator.go index 89dae80f..1281c779 100644 --- a/fd_operator.go +++ b/fd_operator.go @@ -25,7 +25,6 @@ const ( opread int32 = 1 opwrite int32 = 2 opreadwrite int32 = 3 - ophup int32 = 4 ) // FDOperator is a collection 
of operations on file descriptors. @@ -51,8 +50,8 @@ type FDOperator struct { // poll is the registered location of the file descriptor. poll Poll - mode int32 - throttled int32 + // protect only detach once + detached int32 // private, used by operatorCache next *FDOperator @@ -61,6 +60,9 @@ type FDOperator struct { } func (op *FDOperator) Control(event PollEvent) error { + if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 { + return nil + } return op.poll.Control(op, event) } @@ -68,14 +70,6 @@ func (op *FDOperator) Free() { op.poll.Free(op) } -func (op *FDOperator) getMode() int32 { - return atomic.LoadInt32(&op.mode) -} - -func (op *FDOperator) setMode(mode int32) { - atomic.StoreInt32(&op.mode, mode) -} - func (op *FDOperator) do() (can bool) { return atomic.CompareAndSwapInt32(&op.state, 1, 2) } @@ -112,6 +106,5 @@ func (op *FDOperator) reset() { op.Inputs, op.InputAck = nil, nil op.Outputs, op.OutputAck = nil, nil op.poll = nil - op.mode = 0 - op.throttled = 0 + op.detached = 0 } diff --git a/poll.go b/poll.go index ace07133..915b1f9d 100644 --- a/poll.go +++ b/poll.go @@ -48,31 +48,23 @@ type PollEvent int const ( // PollReadable is used to monitor whether the FDOperator registered by // listener and connection is readable or closed. - PollReadable PollEvent = 0x1 + PollReadable PollEvent = iota + 1 // PollWritable is used to monitor whether the FDOperator created by the dialer is writable or closed. // ET mode must be used (still need to poll hup after being writable) - PollWritable PollEvent = 0x2 + PollWritable // PollDetach is used to remove the FDOperator from poll. - PollDetach PollEvent = 0x3 + PollDetach // PollR2RW is used to monitor writable for FDOperator, // which is only called when the socket write buffer is full. - PollR2RW PollEvent = 0x4 + PollR2RW // PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW. 
- PollRW2R PollEvent = 0x5 + PollRW2R // PollRW2W is used to remove the readable monitor of FDOperator. - PollRW2W PollEvent = 0x6 + PollRW2W // PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W. - PollW2RW PollEvent = 0x7 - PollW2Hup PollEvent = 0x8 - - // PollR2Hup is used to remove the readable monitor of FDOperator. - PollR2Hup PollEvent = 0x9 - // PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup. - PollHup2R PollEvent = 0xA - // PollHup2W is used to add the writeable monitor of FDOperator. - PollHup2W PollEvent = 0xB + PollW2RW ) diff --git a/poll_default_bsd.go b/poll_default_bsd.go index 8fda9c35..47dcd1da 100644 --- a/poll_default_bsd.go +++ b/poll_default_bsd.go @@ -116,7 +116,7 @@ func (p *defaultPoll) Wait() error { } if triggerHup { // if peer closed with throttled state, we should ensure we read all left data to avoid data loss - if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil { + if triggerRead && operator.Inputs != nil { var leftRead int // read all left data if peer send and close if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) { @@ -183,14 +183,11 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: operator.inuse() - operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE case PollWritable: operator.inuse() - operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollDetach: - operator.setMode(ophup) if operator.OnWrite != nil { // means WaitWrite finished evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE } else { @@ -198,29 +195,13 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { } p.delOperator(operator) case PollR2RW: - operator.setMode(opreadwrite) evs[0].Filter, 
evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE case PollRW2R: - operator.setMode(opread) evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE case PollRW2W: - operator.setMode(opwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE case PollW2RW: - operator.setMode(opreadwrite) evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE - case PollR2Hup: - operator.setMode(ophup) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE - case PollW2Hup: - operator.setMode(ophup) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE - case PollHup2R: - operator.setMode(opread) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE - case PollHup2W: - operator.setMode(opwrite) - evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE } _, err := syscall.Kevent(p.fd, evs, nil, nil) return err diff --git a/poll_default_bsd_test.go b/poll_default_bsd_test.go new file mode 100644 index 00000000..92185f9e --- /dev/null +++ b/poll_default_bsd_test.go @@ -0,0 +1,83 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +//go:build darwin +// +build darwin + +package netpoll + +import ( + "syscall" + "testing" +) + +func TestKqueueEvent(t *testing.T) { + kqfd, err := syscall.Kqueue() + defer syscall.Close(kqfd) + _, err = syscall.Kevent(kqfd, []syscall.Kevent_t{{ + Ident: 0, + Filter: syscall.EVFILT_USER, + Flags: syscall.EV_ADD | syscall.EV_CLEAR, + }}, nil, nil) + MustNil(t, err) + + rfd, wfd := GetSysFdPairs() + defer syscall.Close(rfd) + defer syscall.Close(wfd) + + // add read event + changes := make([]syscall.Kevent_t, 1) + changes[0].Ident = uint64(rfd) + changes[0].Filter = syscall.EVFILT_READ + changes[0].Flags = syscall.EV_ADD + _, err = syscall.Kevent(kqfd, changes, nil, nil) + MustNil(t, err) + + // write + send := []byte("hello") + recv := make([]byte, 5) + _, err = syscall.Write(wfd, send) + MustNil(t, err) + + // check readable + events := make([]syscall.Kevent_t, 128) + n, err := syscall.Kevent(kqfd, nil, events, nil) + MustNil(t, err) + Equal(t, n, 1) + Assert(t, events[0].Filter == syscall.EVFILT_READ) + // read + _, err = syscall.Read(rfd, recv) + MustNil(t, err) + Equal(t, string(recv), string(send)) + + // delete read + changes[0].Ident = uint64(rfd) + changes[0].Filter = syscall.EVFILT_READ + changes[0].Flags = syscall.EV_DELETE + _, err = syscall.Kevent(kqfd, changes, nil, nil) + MustNil(t, err) + + // write + _, err = syscall.Write(wfd, send) + MustNil(t, err) + + // check readable + n, err = syscall.Kevent(kqfd, nil, events, &syscall.Timespec{Sec: 1}) + MustNil(t, err) + Equal(t, n, 0) + // read + _, err = syscall.Read(rfd, recv) + MustNil(t, err) + Equal(t, string(recv), string(send)) +} diff --git a/poll_default_linux.go b/poll_default_linux.go index f51e10c6..8637a577 100644 --- a/poll_default_linux.go +++ b/poll_default_linux.go @@ -168,8 +168,7 @@ func (p *defaultPoll) handler(events []epollevent) (closed bool) { } } if triggerHup { - // if peer closed with throttled state, we should ensure we read all left data to avoid data loss - if (triggerRead 
|| atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil { + if triggerRead && operator.Inputs != nil { // read all left data if peer send and close var leftRead int // read all left data if peer send and close @@ -245,41 +244,21 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error { switch event { case PollReadable: // server accept a new connection and wait read operator.inuse() - operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_ADD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollWritable: // client create a new connection and wait connect finished operator.inuse() - operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_ADD, EPOLLET|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollDetach: // deregister - if operator.getMode() == opdetach { - // protect only detach once - return nil - } - operator.setMode(opdetach) p.delOperator(operator) op, evt.events = syscall.EPOLL_CTL_DEL, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollR2RW: // connection wait read/write - operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2R: // connection wait read - operator.setMode(opread) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollRW2W: - operator.setMode(opwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR case PollW2RW: - operator.setMode(opreadwrite) op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollR2Hup, PollW2Hup: - operator.setMode(ophup) - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollHup2R: - operator.setMode(opread) - op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLIN|syscall.EPOLLRDHUP|syscall.EPOLLERR - case PollHup2W: - operator.setMode(opwrite) 
- op, evt.events = syscall.EPOLL_CTL_MOD, syscall.EPOLLOUT|syscall.EPOLLRDHUP|syscall.EPOLLERR } return EpollCtl(p.fd, op, operator.FD, &evt) }