Skip to content

Commit

Permalink
Add BlockUntilContext, which respects context cancellation.
Browse files Browse the repository at this point in the history
BlockUntil is easy to misjudge and when callers get that wrong, the test
blocks forever and eventually times out.

Also deletes notifyBlockers and its test, inlining this function at its
only call point.
  • Loading branch information
DPJacques committed Mar 25, 2023
1 parent a89700c commit bd7c02c
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 71 deletions.
69 changes: 47 additions & 22 deletions clockwork.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clockwork

import (
"context"
"sort"
"sync"
"time"
Expand Down Expand Up @@ -119,22 +120,7 @@ func (fc *fakeClock) After(d time.Duration) <-chan time.Time {
return fc.NewTimer(d).Chan()
}

// notifyBlockers closes the receive channel for all blockers waiting for the
// current number of waiters (or fewer).
func (fc *fakeClock) notifyBlockers() {
var blocked []*blocker
count := len(fc.waiters)
for _, b := range fc.blockers {
if b.count <= count {
close(b.ch)
continue
}
blocked = append(blocked, b)
}
fc.blockers = blocked
}

// Sleep blocks until the given duration has past on the fakeClock.
// Sleep blocks until the given duration has passed on the fakeClock.
func (fc *fakeClock) Sleep(d time.Duration) {
<-fc.After(d)
}
Expand All @@ -146,7 +132,7 @@ func (fc *fakeClock) Now() time.Time {
return fc.time
}

// Since returns the duration that has past since the given time on the
// Since returns the duration that has passed since the given time on the
// fakeClock.
func (fc *fakeClock) Since(t time.Time) time.Duration {
return fc.Now().Sub(t)
Expand Down Expand Up @@ -227,21 +213,49 @@ func (fc *fakeClock) Advance(d time.Duration) {
}

// BlockUntil blocks until the fakeClock has the given number of waiters.
//
// Prefer BlockUntilContext, which offers context cancellation to prevent
// deadlock.
//
// Deprecation warning: This function might be deprecated in later versions.
func (fc *fakeClock) BlockUntil(n int) {
b := fc.newBlocker(n)
if b == nil {
return
}
<-b.ch
}

// BlockUntilContext blocks until the fakeClock has the given number of waiters
// or the context is cancelled.
func (fc *fakeClock) BlockUntilContext(ctx context.Context, n int) error {
b := fc.newBlocker(n)
if b == nil {
return nil
}

select {
case <-b.ch:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

func (fc *fakeClock) newBlocker(n int) *blocker {
fc.l.Lock()
defer fc.l.Unlock()
// Fast path: we already have >= n waiters.
if len(fc.waiters) >= n {
fc.l.Unlock()
return
return nil
}
// Set up a new blocker to wait for more waiters.
b := &blocker{
count: n,
ch: make(chan struct{}),
}
fc.blockers = append(fc.blockers, b)
fc.l.Unlock()
<-b.ch
return b
}

// stop stops an expirer, returning true if the expirer was stopped.
Expand Down Expand Up @@ -291,7 +305,18 @@ func (fc *fakeClock) setExpirer(e expirer, d time.Duration) {
sort.Slice(fc.waiters, func(i int, j int) bool {
return fc.waiters[i].expiry().Before(fc.waiters[j].expiry())
})
fc.notifyBlockers()

// Notify blockers of our new waiter.
var blocked []*blocker
count := len(fc.waiters)
for _, b := range fc.blockers {
if b.count <= count {
close(b.ch)
continue
}
blocked = append(blocked, b)
}
fc.blockers = blocked
}

// firer is used by fakeTimer and fakeTicker used to help implement expirer.
Expand Down
79 changes: 30 additions & 49 deletions clockwork_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package clockwork

import (
"context"
"errors"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -87,55 +89,6 @@ func TestFakeClockAfter(t *testing.T) {
}
}

func TestNotifyBlockers(t *testing.T) {
t.Parallel()
b1 := &blocker{1, make(chan struct{})}
b2 := &blocker{2, make(chan struct{})}
b3 := &blocker{5, make(chan struct{})}
b4 := &blocker{10, make(chan struct{})}
b5 := &blocker{10, make(chan struct{})}
fc := fakeClock{
blockers: []*blocker{b1, b2, b3, b4, b5},
waiters: []expirer{nil, nil},
}
fc.notifyBlockers()
if n := len(fc.blockers); n != 3 {
t.Fatalf("got %d blockers, want %d", n, 3)
}
select {
case <-b1.ch:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for channel close!")
}
select {
case <-b2.ch:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for channel close!")
}
for len(fc.waiters) < 10 {
fc.waiters = append(fc.waiters, nil)
}
fc.notifyBlockers()
if n := len(fc.blockers); n != 0 {
t.Fatalf("got %d blockers, want %d", n, 0)
}
select {
case <-b3.ch:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for channel close!")
}
select {
case <-b4.ch:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for channel close!")
}
select {
case <-b5.ch:
case <-time.After(time.Second):
t.Fatalf("timed out waiting for channel close!")
}
}

func TestNewFakeClock(t *testing.T) {
t.Parallel()
fc := NewFakeClock()
Expand Down Expand Up @@ -186,6 +139,34 @@ func TestTwoBlockersOneBlock(t *testing.T) {
ft2.Stop()
}

func TestBlockUntilContext(t *testing.T) {
t.Parallel()
fc := &fakeClock{}

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

blockCtx, cancelBlock := context.WithCancel(ctx)
errCh := make(chan error)

go func() {
select {
case errCh <- fc.BlockUntilContext(blockCtx, 2):
case <-ctx.Done(): // Error case, captured below.
}
}()
cancelBlock()

select {
case err := <-errCh:
if !errors.Is(err, context.Canceled) {
t.Errorf("BlockUntilContext returned %v, want context.Canceled.", err)
}
case <-ctx.Done():
t.Errorf("Never receved error on context cancellation.")
}
}

func TestAfterDeliveryInOrder(t *testing.T) {
t.Parallel()
fc := &fakeClock{}
Expand Down

0 comments on commit bd7c02c

Please sign in to comment.