diff --git a/internal/quic/conn.go b/internal/quic/conn.go index 5601b989e3..90e6739630 100644 --- a/internal/quic/conn.go +++ b/internal/quic/conn.go @@ -7,6 +7,7 @@ package quic import ( + "context" "crypto/tls" "errors" "fmt" @@ -71,6 +72,7 @@ type connTestHooks interface { nextMessage(msgc chan any, nextTimeout time.Time) (now time.Time, message any) handleTLSEvent(tls.QUICEvent) newConnID(seq int64) ([]byte, error) + waitAndLockGate(ctx context.Context, g *gate) error } func newConn(now time.Time, side connSide, initialConnID []byte, peerAddr netip.AddrPort, config *Config, l connListener, hooks connTestHooks) (*Conn, error) { @@ -299,6 +301,13 @@ func (c *Conn) runOnLoop(f func(now time.Time, c *Conn)) error { return nil } +func (c *Conn) waitAndLockGate(ctx context.Context, g *gate) error { + if c.testHooks != nil { + return c.testHooks.waitAndLockGate(ctx, g) + } + return g.waitAndLockContext(ctx) +} + // abort terminates a connection with an error. func (c *Conn) abort(now time.Time, err error) { if c.errForPeer == nil { diff --git a/internal/quic/conn_async_test.go b/internal/quic/conn_async_test.go new file mode 100644 index 0000000000..2078325a53 --- /dev/null +++ b/internal/quic/conn_async_test.go @@ -0,0 +1,185 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.21 + +package quic + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "runtime" + "sync" +) + +// asyncTestState permits handling asynchronous operations in a synchronous test. +// +// For example, a test may want to write to a stream and observe that +// STREAM frames are sent with the contents of the write in response +// to MAX_STREAM_DATA frames received from the peer. 
+// The Stream.Write is an asynchronous operation, but the test is simpler +// if we can start the write, observe the first STREAM frame sent, +// send a MAX_STREAM_DATA frame, observe the next STREAM frame sent, etc. +// +// We do this by instrumenting points where operations can block. +// We start async operations like Write in a goroutine, +// and wait for the operation to either finish or hit a blocking point. +// When the connection event loop is idle, we check a list of +// blocked operations to see if any can be woken. +type asyncTestState struct { + mu sync.Mutex + notify chan struct{} + blocked map[*blockedAsync]struct{} +} + +// An asyncOp is an asynchronous operation that results in (T, error). +type asyncOp[T any] struct { + v T + err error + + caller string + state *asyncTestState + donec chan struct{} + cancelFunc context.CancelFunc +} + +// cancel cancels the async operation's context, and waits for +// the operation to complete. +func (a *asyncOp[T]) cancel() { + select { + case <-a.donec: + return // already done + default: + } + a.cancelFunc() + <-a.state.notify + select { + case <-a.donec: + default: + panic(fmt.Errorf("%v: async op failed to finish after being canceled", a.caller)) + } +} + +var errNotDone = errors.New("async op is not done") + +// result returns the result of the async operation. +// It returns errNotDone if the operation is still in progress. +// +// Note that unlike a traditional async/await, this doesn't block +// waiting for the operation to complete. Since tests have full +// control over the progress of operations, an asyncOp can only +// become done in reaction to the test taking some action. +func (a *asyncOp[T]) result() (v T, err error) { + select { + case <-a.donec: + return a.v, a.err + default: + return v, errNotDone + } +} + +// A blockedAsync is a blocked async operation. +// +// Currently, the only type of blocked operation is one waiting on a gate. 
+type blockedAsync struct { + g *gate + donec chan struct{} // closed when the operation is unblocked +} + +type asyncContextKey struct{} + +// runAsync starts an asynchronous operation. +// +// The function f should call a blocking function such as +// Stream.Write or Conn.AcceptStream and return its result. +// It must use the provided context. +func runAsync[T any](ts *testConn, f func(context.Context) (T, error)) *asyncOp[T] { + as := &ts.asyncTestState + if as.notify == nil { + as.notify = make(chan struct{}) + as.blocked = make(map[*blockedAsync]struct{}) + } + _, file, line, _ := runtime.Caller(1) + ctx := context.WithValue(context.Background(), asyncContextKey{}, true) + ctx, cancel := context.WithCancel(ctx) + a := &asyncOp[T]{ + state: as, + caller: fmt.Sprintf("%v:%v", filepath.Base(file), line), + donec: make(chan struct{}), + cancelFunc: cancel, + } + go func() { + a.v, a.err = f(ctx) + close(a.donec) + as.notify <- struct{}{} + }() + ts.t.Cleanup(func() { + if _, err := a.result(); err == errNotDone { + ts.t.Errorf("%v: async operation is still executing at end of test", a.caller) + a.cancel() + } + }) + // Wait for the operation to either finish or block. + <-as.notify + return a +} + +// waitAndLockGate replaces gate.waitAndLock in tests. +func (as *asyncTestState) waitAndLockGate(ctx context.Context, g *gate) error { + if g.lockIfSet() { + // Gate can be acquired without blocking. + return nil + } + if err := ctx.Err(); err != nil { + // Context has already expired. + return err + } + if ctx.Value(asyncContextKey{}) == nil { + // Context is not one that we've created, and hasn't expired. + // This probably indicates that we've tried to perform a + // blocking operation without using the async test harness here, + // which may have unpredictable results. + panic("blocking async point with unexpected Context") + } + // Record this as a pending blocking operation. 
+ as.mu.Lock() + b := &blockedAsync{ + g: g, + donec: make(chan struct{}), + } + as.blocked[b] = struct{}{} + as.mu.Unlock() + // Notify the creator of the operation that we're blocked, + // and wait to be woken up. + as.notify <- struct{}{} + select { + case <-b.donec: + case <-ctx.Done(): + return ctx.Err() + } + return nil +} + +// wakeAsync tries to wake up a blocked async operation. +// It returns true if one was woken, false otherwise. +func (as *asyncTestState) wakeAsync() bool { + as.mu.Lock() + var woken *blockedAsync + for w := range as.blocked { + if w.g.lockIfSet() { + woken = w + delete(as.blocked, woken) + break + } + } + as.mu.Unlock() + if woken == nil { + return false + } + close(woken.donec) + <-as.notify // must not hold as.mu while blocked here + return true +} diff --git a/internal/quic/conn_streams.go b/internal/quic/conn_streams.go index 82e9028609..f626323b5a 100644 --- a/internal/quic/conn_streams.go +++ b/internal/quic/conn_streams.go @@ -36,7 +36,7 @@ func (c *Conn) streamsInit() { // AcceptStream waits for and returns the next stream created by the peer. func (c *Conn) AcceptStream(ctx context.Context) (*Stream, error) { - return c.streams.queue.get(ctx) + return c.streams.queue.getWithHooks(ctx, c.testHooks) } // NewStream creates a stream. 
diff --git a/internal/quic/conn_streams_test.go b/internal/quic/conn_streams_test.go index 8481a604c5..bcbbe81ce3 100644 --- a/internal/quic/conn_streams_test.go +++ b/internal/quic/conn_streams_test.go @@ -95,6 +95,35 @@ func TestStreamsAccept(t *testing.T) { } } +func TestStreamsBlockingAccept(t *testing.T) { + tc := newTestConn(t, serverSide) + tc.handshake() + + a := runAsync(tc, func(ctx context.Context) (*Stream, error) { + return tc.conn.AcceptStream(ctx) + }) + if _, err := a.result(); err != errNotDone { + tc.t.Fatalf("AcceptStream() = _, %v; want errNotDone", err) + } + + sid := newStreamID(clientSide, bidiStream, 0) + tc.writeFrames(packetType1RTT, + debugFrameStream{ + id: sid, + }) + + s, err := a.result() + if err != nil { + t.Fatalf("conn.AcceptStream() = _, %v, want stream", err) + } + if got, want := s.id, sid; got != want { + t.Fatalf("conn.AcceptStream() = stream %v, want %v", got, want) + } + if got, want := s.IsReadOnly(), false; got != want { + t.Fatalf("s.IsReadOnly() = %v, want %v", got, want) + } +} + func TestStreamsStreamNotCreated(t *testing.T) { // "An endpoint MUST terminate the connection with error STREAM_STATE_ERROR // if it receives a STREAM frame for a locally initiated stream that has diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 110b0a9f90..5aad69f4d1 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -144,6 +144,8 @@ type testConn struct { // Frame types to ignore in tests. ignoreFrames map[byte]bool + + asyncTestState } type keyData struct { @@ -700,21 +702,26 @@ func (tc *testConnHooks) handleTLSEvent(e tls.QUICEvent) { // nextMessage is called by the Conn's event loop to request its next event. 
func (tc *testConnHooks) nextMessage(msgc chan any, timer time.Time) (now time.Time, m any) { tc.timer = timer - if !timer.IsZero() && !timer.After(tc.now) { - if timer.Equal(tc.timerLastFired) { - // If the connection timer fires at time T, the Conn should take some - // action to advance the timer into the future. If the Conn reschedules - // the timer for the same time, it isn't making progress and we have a bug. - tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.now, timer) - } else { - tc.timerLastFired = timer - return tc.now, timerEvent{} + for { + if !timer.IsZero() && !timer.After(tc.now) { + if timer.Equal(tc.timerLastFired) { + // If the connection timer fires at time T, the Conn should take some + // action to advance the timer into the future. If the Conn reschedules + // the timer for the same time, it isn't making progress and we have a bug. + tc.t.Errorf("connection timer spinning; now=%v timer=%v", tc.now, timer) + } else { + tc.timerLastFired = timer + return tc.now, timerEvent{} + } + } + select { + case m := <-msgc: + return tc.now, m + default: + } + if !tc.wakeAsync() { + break } - } - select { - case m := <-msgc: - return tc.now, m - default: } // If the message queue is empty, then the conn is idle. if tc.idlec != nil { diff --git a/internal/quic/queue.go b/internal/quic/queue.go index 9bb71ca3f4..489721a8af 100644 --- a/internal/quic/queue.go +++ b/internal/quic/queue.go @@ -45,8 +45,20 @@ func (q *queue[T]) put(v T) bool { // get removes the first item from the queue, blocking until ctx is done, an item is available, // or the queue is closed. func (q *queue[T]) get(ctx context.Context) (T, error) { + return q.getWithHooks(ctx, nil) +} + +// getWithHooks is get, but uses testHooks for locking when non-nil. +// This is a bit of a layer violation, but a simplification overall. 
+func (q *queue[T]) getWithHooks(ctx context.Context, testHooks connTestHooks) (T, error) { var zero T - if err := q.gate.waitAndLockContext(ctx); err != nil { + var err error + if testHooks != nil { + err = testHooks.waitAndLockGate(ctx, &q.gate) + } else { + err = q.gate.waitAndLockContext(ctx) + } + if err != nil { return zero, err } defer q.unlock()