Skip to content

Commit

Permalink
context: add AfterFunc
Browse files Browse the repository at this point in the history
Add an AfterFunc function, which registers a function to run after
a context has been canceled.

Add support for contexts that implement an AfterFunc method, which
can be used to avoid the need to start a new goroutine watching
the Done channel when propagating cancellation signals.

Fixes #57928

Change-Id: If0b2cdcc4332961276a1ff57311338e74916259c
Reviewed-on: https://go-review.googlesource.com/c/go/+/482695
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
Reviewed-by: Sameer Ajmani <sameer@golang.org>
  • Loading branch information
neild committed Apr 19, 2023
1 parent 9d53d7a commit 54d4299
Show file tree
Hide file tree
Showing 6 changed files with 533 additions and 44 deletions.
1 change: 1 addition & 0 deletions api/next/57928.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pkg context, func AfterFunc(Context, func()) func() bool #57928
141 changes: 141 additions & 0 deletions src/context/afterfunc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package context_test

import (
"context"
"sync"
"testing"
"time"
)

// afterFuncContext is a context that's not one of the types
// defined in context.go, that supports registering AfterFuncs.
type afterFuncContext struct {
mu sync.Mutex
afterFuncs map[*struct{}]func()
done chan struct{}
err error
}

func newAfterFuncContext() context.Context {
return &afterFuncContext{}
}

func (c *afterFuncContext) Deadline() (time.Time, bool) {
return time.Time{}, false
}

func (c *afterFuncContext) Done() <-chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
if c.done == nil {
c.done = make(chan struct{})
}
return c.done
}

func (c *afterFuncContext) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}

func (c *afterFuncContext) Value(key any) any {
return nil
}

func (c *afterFuncContext) AfterFunc(f func()) func() bool {
c.mu.Lock()
defer c.mu.Unlock()
k := &struct{}{}
if c.afterFuncs == nil {
c.afterFuncs = make(map[*struct{}]func())
}
c.afterFuncs[k] = f
return func() bool {
c.mu.Lock()
defer c.mu.Unlock()
_, ok := c.afterFuncs[k]
delete(c.afterFuncs, k)
return ok
}
}

func (c *afterFuncContext) cancel(err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.err != nil {
return
}
c.err = err
for _, f := range c.afterFuncs {
go f()
}
c.afterFuncs = nil
}

func TestCustomContextAfterFuncCancel(t *testing.T) {
ctx0 := &afterFuncContext{}
ctx1, cancel := context.WithCancel(ctx0)
defer cancel()
ctx0.cancel(context.Canceled)
<-ctx1.Done()
}

func TestCustomContextAfterFuncTimeout(t *testing.T) {
ctx0 := &afterFuncContext{}
ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration)
defer cancel()
ctx0.cancel(context.Canceled)
<-ctx1.Done()
}

func TestCustomContextAfterFuncAfterFunc(t *testing.T) {
ctx0 := &afterFuncContext{}
donec := make(chan struct{})
stop := context.AfterFunc(ctx0, func() {
close(donec)
})
defer stop()
ctx0.cancel(context.Canceled)
<-donec
}

func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) {
ctx0 := &afterFuncContext{}
_, cancel := context.WithCancel(ctx0)
if got, want := len(ctx0.afterFuncs), 1; got != want {
t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
}
cancel()
if got, want := len(ctx0.afterFuncs), 0; got != want {
t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want)
}
}

func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) {
ctx0 := &afterFuncContext{}
_, cancel := context.WithTimeout(ctx0, veryLongDuration)
if got, want := len(ctx0.afterFuncs), 1; got != want {
t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
}
cancel()
if got, want := len(ctx0.afterFuncs), 0; got != want {
t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want)
}
}

func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) {
ctx0 := &afterFuncContext{}
stop := context.AfterFunc(ctx0, func() {})
if got, want := len(ctx0.afterFuncs), 1; got != want {
t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
}
stop()
if got, want := len(ctx0.afterFuncs), 0; got != want {
t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want)
}
}
167 changes: 126 additions & 41 deletions src/context/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ func withCancel(parent Context) *cancelCtx {
if parent == nil {
panic("cannot create context from nil parent")
}
c := &cancelCtx{Context: parent}
propagateCancel(parent, c)
c := &cancelCtx{}
c.propagateCancel(parent, c)
return c
}

Expand All @@ -289,48 +289,72 @@ func Cause(c Context) error {
return nil
}

// goroutines counts the number of goroutines ever created; for testing.
var goroutines atomic.Int32

// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent Context, child canceler) {
done := parent.Done()
if done == nil {
return // parent is never canceled
// AfterFunc arranges to call f in its own goroutine after ctx is done
// (cancelled or timed out).
// If ctx is already done, AfterFunc calls f immediately in its own goroutine.
//
// Multiple calls to AfterFunc on a context operate independently;
// one does not replace another.
//
// Calling the returned stop function stops the association of ctx with f.
// It returns true if the call stopped f from being run.
// If stop returns false,
// either the context is done and f has been started in its own goroutine;
// or f was already stopped.
// The stop function does not wait for f to complete before returning.
// If the caller needs to know whether f is completed,
// it must coordinate with f explicitly.
//
// If ctx has a "AfterFunc(func()) func() bool" method,
// AfterFunc will use it to schedule the call.
func AfterFunc(ctx Context, f func()) (stop func() bool) {
a := &afterFuncCtx{
f: f,
}
a.cancelCtx.propagateCancel(ctx, a)
return func() bool {
stopped := false
a.once.Do(func() {
stopped = true
})
if stopped {
a.cancel(true, Canceled, nil)
}
return stopped
}
}

select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err(), Cause(parent))
return
default:
}
type afterFuncer interface {
AfterFunc(func()) func() bool
}

if p, ok := parentCancelCtx(parent); ok {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err, p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
}
p.children[child] = struct{}{}
}
p.mu.Unlock()
} else {
goroutines.Add(1)
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
type afterFuncCtx struct {
cancelCtx
once sync.Once // either starts running f or stops f from running
f func()
}

func (a *afterFuncCtx) cancel(removeFromParent bool, err, cause error) {
a.cancelCtx.cancel(false, err, cause)
if removeFromParent {
removeChild(a.Context, a)
}
a.once.Do(func() {
go a.f()
})
}

// A stopCtx is used as the parent context of a cancelCtx when
// an AfterFunc has been registered with the parent.
// It holds the stop function used to unregister the AfterFunc.
type stopCtx struct {
Context
stop func() bool
}

// goroutines counts the number of goroutines ever created; for testing.
var goroutines atomic.Int32

// &cancelCtxKey is the key that a cancelCtx returns itself for.
var cancelCtxKey int

Expand Down Expand Up @@ -358,6 +382,10 @@ func parentCancelCtx(parent Context) (*cancelCtx, bool) {

// removeChild removes a context from its parent.
func removeChild(parent Context, child canceler) {
if s, ok := parent.(stopCtx); ok {
s.stop()
return
}
p, ok := parentCancelCtx(parent)
if !ok {
return
Expand Down Expand Up @@ -424,6 +452,64 @@ func (c *cancelCtx) Err() error {
return err
}

// propagateCancel arranges for child to be canceled when parent is.
// It sets the parent context of cancelCtx.
func (c *cancelCtx) propagateCancel(parent Context, child canceler) {
c.Context = parent

done := parent.Done()
if done == nil {
return // parent is never canceled
}

select {
case <-done:
// parent is already canceled
child.cancel(false, parent.Err(), Cause(parent))
return
default:
}

if p, ok := parentCancelCtx(parent); ok {
// parent is a *cancelCtx, or derives from one.
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
child.cancel(false, p.err, p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
}
p.children[child] = struct{}{}
}
p.mu.Unlock()
return
}

if a, ok := parent.(afterFuncer); ok {
// parent implements an AfterFunc method.
c.mu.Lock()
stop := a.AfterFunc(func() {
child.cancel(false, parent.Err(), Cause(parent))
})
c.Context = stopCtx{
Context: parent,
stop: stop,
}
c.mu.Unlock()
return
}

goroutines.Add(1)
go func() {
select {
case <-parent.Done():
child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
}

type stringer interface {
String() string
}
Expand Down Expand Up @@ -533,10 +619,9 @@ func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, Cance
return WithCancel(parent)
}
c := &timerCtx{
cancelCtx: cancelCtx{Context: parent},
deadline: d,
deadline: d,
}
propagateCancel(parent, c)
c.cancelCtx.propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
c.cancel(true, DeadlineExceeded, cause) // deadline has already passed
Expand Down
Loading

0 comments on commit 54d4299

Please sign in to comment.