Skip to content

Commit

Permalink
fix: race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
joway committed Jan 10, 2024
1 parent 5995a5b commit c99b425
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 120 deletions.
17 changes: 5 additions & 12 deletions connection_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,17 +503,10 @@ func (c *connection) flush() error {
if c.outputBuffer.IsEmpty() {
return nil
}
if c.operator.getMode() == ophup {
// triggered read throttled, so here shouldn't trigger read event again
err = c.operator.Control(PollHup2W)
} else {
err = c.operator.Control(PollR2RW)
}
c.operator.done()
if err != nil {
return Exception(err, "when flush")
}

// no need to check if resume write successfully
// if resume failed, the connection will be triggered triggerWrite(err), and waitFlush will return err
c.resumeWrite()
return c.waitFlush()
}

Expand Down Expand Up @@ -546,8 +539,8 @@ func (c *connection) waitFlush() (err error) {
default:
}
// if timeout, remove write event from poller
// we cannot flush it again, since we don't if the poller is still process outputBuffer
c.operator.Control(PollRW2R)
// we cannot flush it again, since we don't know if the poller is still processing outputBuffer
c.pauseWrite()
return Exception(ErrWriteTimeout, c.remoteAddr.String())
}
}
58 changes: 29 additions & 29 deletions connection_reactor.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ func (c *connection) closeBuffer() {

// inputs implements FDOperator.
func (c *connection) inputs(vs [][]byte) (rs [][]byte) {
// trigger throttle
if c.readBufferThreshold > 0 && int64(c.inputBuffer.Len()) >= c.readBufferThreshold {
c.pauseRead()
return
}

vs[0] = c.inputBuffer.book(c.bookSize, c.maxSize)
return vs[:1]
}
Expand Down Expand Up @@ -123,6 +129,7 @@ func (c *connection) inputAck(n int) (err error) {
func (c *connection) outputs(vs [][]byte) (rs [][]byte, supportZeroCopy bool) {
if c.outputBuffer.IsEmpty() {
c.pauseWrite()
c.triggerWrite(nil)
return rs, c.supportZeroCopy
}
rs = c.outputBuffer.GetBytes(vs)
Expand All @@ -137,50 +144,43 @@ func (c *connection) outputAck(n int) (err error) {
}
if c.outputBuffer.IsEmpty() {
c.pauseWrite()
c.triggerWrite(nil)
}
return nil
}

/* The race description of operator event monitoring
- Pause operation will remove old event monitor of operator
- Resume operation will add new event monitor of operator
- Only poller could use Pause to remove event monitor, and poller already hold the op.do() locker
- Only user could use Resume, and user's operation maybe compete with poller's operation
- If competition happen, because of all resume operation will monitor all events, it's safe to do that with a race condition.
* If resume first and pause latter, poller will monitor the accurate events it needs.
* If pause first and resume latter, poller will monitor the duplicate events which will be removed after next poller triggered.
And poller will ensure to remove the duplicate events.
- If there is no readBufferThreshold option, the code path will be more simple and efficient.
*/

// pauseWrite removed the monitoring of write events.
// pauseWrite used in poller
func (c *connection) pauseWrite() {
switch c.operator.getMode() {
case opreadwrite:
c.operator.Control(PollRW2R)
case opwrite:
c.operator.Control(PollW2Hup)
}
c.triggerWrite(nil)
c.operator.Control(PollRW2R)
}

// resumeWrite add the monitoring of write events.
// resumeWrite used by users
func (c *connection) resumeWrite() {
c.operator.Control(PollR2RW)
}

// pauseRead removed the monitoring of read events.
// pauseRead used in poller
func (c *connection) pauseRead() {
// Note that the poller ensure that every fd should read all left data in socket buffer before detach it.
// So the operator mode should never be ophup.
var changeTo PollEvent
switch c.operator.getMode() {
case opread:
changeTo = PollR2Hup
case opreadwrite:
changeTo = PollRW2W
}
if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 0, 1) {
c.operator.Control(changeTo)
}
c.operator.Control(PollRW2W)
}

// resumeRead add the monitoring of read events.
// resumeRead used by users
func (c *connection) resumeRead() {
var changeTo PollEvent
switch c.operator.getMode() {
case ophup:
changeTo = PollHup2R
case opwrite:
changeTo = PollW2RW
}
if changeTo > 0 && atomic.CompareAndSwapInt32(&c.operator.throttled, 1, 0) {
c.operator.Control(changeTo)
}
c.operator.Control(PollW2RW)
}
10 changes: 1 addition & 9 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -784,18 +784,10 @@ func TestConnectionReadThresholdWithClosed(t *testing.T) {
MustNil(t, err)
t.Logf("read non-throttled data")

// continue read throttled data
// continue read throttled data with EOF
buf, err = connection.Reader().Next(5)
MustNil(t, err)
t.Logf("read throttled data: [%s]", buf)
Equal(t, len(buf), 5)
MustNil(t, err)
err = connection.Reader().Release()
MustNil(t, err)
Equal(t, connection.Reader().Len(), 0)

_, err = connection.Reader().Next(1)
Assert(t, errors.Is(err, ErrEOF))
trigger <- struct{}{}
return nil
}
Expand Down
19 changes: 6 additions & 13 deletions fd_operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ const (
opread int32 = 1
opwrite int32 = 2
opreadwrite int32 = 3
ophup int32 = 4
)

// FDOperator is a collection of operations on file descriptors.
Expand All @@ -51,8 +50,8 @@ type FDOperator struct {
// poll is the registered location of the file descriptor.
poll Poll

mode int32
throttled int32
// protect only detach once
detached int32

// private, used by operatorCache
next *FDOperator
Expand All @@ -61,21 +60,16 @@ type FDOperator struct {
}

func (op *FDOperator) Control(event PollEvent) error {
if event == PollDetach && atomic.AddInt32(&op.detached, 1) > 1 {
return nil
}
return op.poll.Control(op, event)
}

func (op *FDOperator) Free() {
op.poll.Free(op)
}

func (op *FDOperator) getMode() int32 {
return atomic.LoadInt32(&op.mode)
}

func (op *FDOperator) setMode(mode int32) {
atomic.StoreInt32(&op.mode, mode)
}

func (op *FDOperator) do() (can bool) {
return atomic.CompareAndSwapInt32(&op.state, 1, 2)
}
Expand Down Expand Up @@ -112,6 +106,5 @@ func (op *FDOperator) reset() {
op.Inputs, op.InputAck = nil, nil
op.Outputs, op.OutputAck = nil, nil
op.poll = nil
op.mode = 0
op.throttled = 0
op.detached = 0
}
22 changes: 7 additions & 15 deletions poll.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,23 @@ type PollEvent int
const (
// PollReadable is used to monitor whether the FDOperator registered by
// listener and connection is readable or closed.
PollReadable PollEvent = 0x1
PollReadable PollEvent = iota + 1

// PollWritable is used to monitor whether the FDOperator created by the dialer is writable or closed.
// ET mode must be used (still need to poll hup after being writable)
PollWritable PollEvent = 0x2
PollWritable

// PollDetach is used to remove the FDOperator from poll.
PollDetach PollEvent = 0x3
PollDetach

// PollR2RW is used to monitor writable for FDOperator,
// which is only called when the socket write buffer is full.
PollR2RW PollEvent = 0x4
PollR2RW
// PollRW2R is used to remove the writable monitor of FDOperator, generally used with PollR2RW.
PollRW2R PollEvent = 0x5
PollRW2R

// PollRW2W is used to remove the readable monitor of FDOperator.
PollRW2W PollEvent = 0x6
PollRW2W
// PollW2RW is used to add the readable monitor of FDOperator, generally used with PollRW2W.
PollW2RW PollEvent = 0x7
PollW2Hup PollEvent = 0x8

// PollR2Hup is used to remove the readable monitor of FDOperator.
PollR2Hup PollEvent = 0x9
// PollHup2R is used to add the readable monitor of FDOperator, generally used with PollR2Hup.
PollHup2R PollEvent = 0xA
// PollHup2W is used to add the writeable monitor of FDOperator.
PollHup2W PollEvent = 0xB
PollW2RW
)
21 changes: 1 addition & 20 deletions poll_default_bsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (p *defaultPoll) Wait() error {
}
if triggerHup {
// if peer closed with throttled state, we should ensure we read all left data to avoid data loss
if (triggerRead || atomic.LoadInt32(&operator.throttled) > 0) && operator.Inputs != nil {
if triggerRead && operator.Inputs != nil {
var leftRead int
// read all left data if peer send and close
if leftRead, err = readall(operator, barriers[i]); err != nil && !errors.Is(err, ErrEOF) {
Expand Down Expand Up @@ -183,44 +183,25 @@ func (p *defaultPoll) Control(operator *FDOperator, event PollEvent) error {
switch event {
case PollReadable:
operator.inuse()
operator.setMode(opread)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE
case PollWritable:
operator.inuse()
operator.setMode(opwrite)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE
case PollDetach:
operator.setMode(ophup)
if operator.OnWrite != nil { // means WaitWrite finished
evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE
} else {
evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE
}
p.delOperator(operator)
case PollR2RW:
operator.setMode(opreadwrite)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE
case PollRW2R:
operator.setMode(opread)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE
case PollRW2W:
operator.setMode(opwrite)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE
case PollW2RW:
operator.setMode(opreadwrite)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE
case PollR2Hup:
operator.setMode(ophup)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_DELETE
case PollW2Hup:
operator.setMode(ophup)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_DELETE
case PollHup2R:
operator.setMode(opread)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_READ, syscall.EV_ADD|syscall.EV_ENABLE
case PollHup2W:
operator.setMode(opwrite)
evs[0].Filter, evs[0].Flags = syscall.EVFILT_WRITE, syscall.EV_ADD|syscall.EV_ENABLE
}
_, err := syscall.Kevent(p.fd, evs, nil, nil)
return err
Expand Down
83 changes: 83 additions & 0 deletions poll_default_bsd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2024 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build darwin
// +build darwin

package netpoll

import (
"syscall"
"testing"
)

func TestKqueueEvent(t *testing.T) {
kqfd, err := syscall.Kqueue()
defer syscall.Close(kqfd)
_, err = syscall.Kevent(kqfd, []syscall.Kevent_t{{
Ident: 0,
Filter: syscall.EVFILT_USER,
Flags: syscall.EV_ADD | syscall.EV_CLEAR,
}}, nil, nil)
MustNil(t, err)

rfd, wfd := GetSysFdPairs()
defer syscall.Close(rfd)
defer syscall.Close(wfd)

// add read event
changes := make([]syscall.Kevent_t, 1)
changes[0].Ident = uint64(rfd)
changes[0].Filter = syscall.EVFILT_READ
changes[0].Flags = syscall.EV_ADD
_, err = syscall.Kevent(kqfd, changes, nil, nil)
MustNil(t, err)

// write
send := []byte("hello")
recv := make([]byte, 5)
_, err = syscall.Write(wfd, send)
MustNil(t, err)

// check readable
events := make([]syscall.Kevent_t, 128)
n, err := syscall.Kevent(kqfd, nil, events, nil)
MustNil(t, err)
Equal(t, n, 1)
Assert(t, events[0].Filter == syscall.EVFILT_READ)
// read
_, err = syscall.Read(rfd, recv)
MustNil(t, err)
Equal(t, string(recv), string(send))

// delete read
changes[0].Ident = uint64(rfd)
changes[0].Filter = syscall.EVFILT_READ
changes[0].Flags = syscall.EV_DELETE
_, err = syscall.Kevent(kqfd, changes, nil, nil)
MustNil(t, err)

// write
_, err = syscall.Write(wfd, send)
MustNil(t, err)

// check readable
n, err = syscall.Kevent(kqfd, nil, events, &syscall.Timespec{Sec: 1})
MustNil(t, err)
Equal(t, n, 0)
// read
_, err = syscall.Read(rfd, recv)
MustNil(t, err)
Equal(t, string(recv), string(send))
}
Loading

0 comments on commit c99b425

Please sign in to comment.