Skip to content

Commit

Permalink
protect against concurrent use of Stream.Read (#3380)
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann authored Apr 25, 2022
1 parent 823c609 commit ec118e4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
10 changes: 9 additions & 1 deletion receive_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ type receiveStream struct {
resetRemotely bool // set when HandleResetStreamFrame() is called

readChan chan struct{}
readOnce chan struct{} // cap: 1, to protect against concurrent use of Read
deadline time.Time

flowController flowcontrol.StreamFlowController
Expand All @@ -70,6 +71,7 @@ func newReceiveStream(
flowController: flowController,
frameQueue: newFrameSorter(),
readChan: make(chan struct{}, 1),
readOnce: make(chan struct{}, 1),
finalOffset: protocol.MaxByteCount,
version: version,
}
Expand All @@ -81,6 +83,12 @@ func (s *receiveStream) StreamID() protocol.StreamID {

// Read implements io.Reader. It is not thread safe!
func (s *receiveStream) Read(p []byte) (int, error) {
// Concurrent use of Read is not permitted (and doesn't make any sense),
// but sometimes people do it anyway.
// Make sure that we only execute one call at any given time to avoid hard to debug failures.
s.readOnce <- struct{}{}
defer func() { <-s.readOnce }()

s.mutex.Lock()
completed, n, err := s.readImpl(p)
s.mutex.Unlock()
Expand All @@ -105,7 +113,7 @@ func (s *receiveStream) readImpl(p []byte) (bool /*stream completed */, int, err
return false, 0, s.closeForShutdownErr
}

bytesRead := 0
var bytesRead int
var deadlineTimer *utils.Timer
for bytesRead < len(p) {
if s.currentFrame == nil || s.readPosInFrame >= len(s.currentFrame) {
Expand Down
39 changes: 39 additions & 0 deletions receive_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"errors"
"io"
"runtime"
"sync"
"sync/atomic"
"time"

"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -403,6 +405,43 @@ var _ = Describe("Receive Stream", func() {
Expect(n).To(BeZero())
Expect(err).To(MatchError(io.EOF))
})

// Calling Read concurrently doesn't make any sense (and is forbidden),
// but we still want to make sure that we don't complete the stream more than once
// if the user misuses our API.
// This would lead to an INTERNAL_ERROR ("tried to delete unknown outgoing stream"),
// which can be hard to debug.
// Note that even without the protection built into the receiveStream, this test
// is very timing-dependent, and would need to run a few hundred times to trigger the failure.
It("handles concurrent reads", func() {
mockFC.EXPECT().UpdateHighestReceived(protocol.ByteCount(6), gomock.Any()).AnyTimes()
var bytesRead protocol.ByteCount
mockFC.EXPECT().AddBytesRead(gomock.Any()).Do(func(n protocol.ByteCount) { bytesRead += n }).AnyTimes()

var numCompleted int32
mockSender.EXPECT().onStreamCompleted(streamID).Do(func(protocol.StreamID) {
atomic.AddInt32(&numCompleted, 1)
}).AnyTimes()
const num = 3
var wg sync.WaitGroup
wg.Add(num)
for i := 0; i < num; i++ {
go func() {
defer wg.Done()
defer GinkgoRecover()
_, err := str.Read(make([]byte, 8))
Expect(err).To(MatchError(io.EOF))
}()
}
str.handleStreamFrame(&wire.StreamFrame{
Offset: 0,
Data: []byte("foobar"),
Fin: true,
})
wg.Wait()
Expect(bytesRead).To(BeEquivalentTo(6))
Expect(atomic.LoadInt32(&numCompleted)).To(BeEquivalentTo(1))
})
})

It("closes when CloseRemote is called", func() {
Expand Down

0 comments on commit ec118e4

Please sign in to comment.