Skip to content

Commit

Permalink
contextutil: add CancelWithReason
Browse files Browse the repository at this point in the history
contextutil.WithCancel now returns a new context type that holds its
reason for cancellation when canceled with CancelWithReason instead of
its normal cancelfunc.

Release note: None
  • Loading branch information
jordanlewis committed May 22, 2019
1 parent b4fabf9 commit c03da61
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
25 changes: 25 additions & 0 deletions pkg/util/contextutil/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ import (
"fmt"
"net"
"runtime/debug"
"sync/atomic"
"time"
"unsafe"

"github.com/cockroachdb/cockroach/pkg/util/log"
)
Expand All @@ -29,6 +31,29 @@ func WithCancel(parent context.Context) (context.Context, context.CancelFunc) {
return wrap(context.WithCancel(parent))
}

type reasonKey struct{}

type CancelWithReasonFunc func(error)

func WithCancelReason(ctx context.Context) (context.Context, CancelWithReasonFunc) {
ptr := new(unsafe.Pointer)
ctx = context.WithValue(ctx, reasonKey{}, ptr)
ctx, cancel := wrap(context.WithCancel(ctx))
return ctx, func(reason error) {
atomic.StorePointer(ptr, unsafe.Pointer(&reason))
cancel()
}
}

func GetCancelReason(ctx context.Context) error {
i := ctx.Value(reasonKey{})
switch t := i.(type) {
case *unsafe.Pointer:
return *(*error)(atomic.LoadPointer(t))
}
return nil
}

func wrap(ctx context.Context, cancel context.CancelFunc) (context.Context, context.CancelFunc) {
if !log.V(1) {
return ctx, cancel
Expand Down
27 changes: 27 additions & 0 deletions pkg/util/contextutil/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"time"

"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

func TestRunWithTimeout(t *testing.T) {
Expand Down Expand Up @@ -95,3 +96,29 @@ func TestRunWithTimeoutWithoutDeadlineExceeded(t *testing.T) {
"returned error")
}
}

func TestCancelWithReason(t *testing.T) {
ctx := context.Background()

var cancel CancelWithReasonFunc
ctx, cancel = WithCancelReason(ctx)

e := errors.New("hodor")
go func() {
cancel(e)
}()

loop:
for true {
select {
case <-ctx.Done():
break loop
}
time.Sleep(time.Duration(100 * time.Millisecond))
}

expected := "context canceled"
found := ctx.Err().Error()
assert.Equal(t, expected, found)
assert.Equal(t, e, GetCancelReason(ctx))
}

0 comments on commit c03da61

Please sign in to comment.