fn: optimize context guard #9361

Merged
merged 4 commits into from
Jan 10, 2025
49 changes: 23 additions & 26 deletions fn/context_guard.go
@@ -51,6 +51,10 @@ func (g *ContextGuard) Quit() {
cancel()
}

// Clear cancelFns. It is safe to use nil, because no write
// operations to it can happen after g.quit is closed.
g.cancelFns = nil

close(g.quit)
})
}
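
The comment in Quit relies on Go's nil-map semantics. The sketch below is a standalone illustration, assuming cancelFns is a map (its declaration is not shown in this diff); the key and value types used here are placeholders, not the real field type. It shows that lookups, len, delete and range are harmless on a nil map, while a write panics.

```go
package main

import "fmt"

// nilMapBehaviour reports what happens to reads, deletes and writes on a nil
// map. The key and value types are placeholders chosen for this sketch.
func nilMapBehaviour() (found bool, size int, writePanicked bool) {
	var fns map[uint32]func() // nil map

	// Lookups, len, delete and range are all safe on a nil map.
	_, found = fns[1]
	size = len(fns)
	delete(fns, 1)
	for range fns {
		size++ // never reached: a nil map has no entries
	}

	// A write is the only operation that panics.
	defer func() {
		writePanicked = recover() != nil
	}()
	fns[1] = func() {}

	return found, size, writePanicked
}

func main() {
	found, size, writePanicked := nilMapBehaviour()
	fmt.Println(found, size, writePanicked)
}
```

So a wrapped cancel function that only deletes its entry stays safe after Quit, provided nothing tries to store a new entry once g.quit is closed.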
@@ -149,7 +153,7 @@ func (g *ContextGuard) Create(ctx context.Context,
}

if opts.blocking {
g.ctxBlocking(ctx, cancel)
g.ctxBlocking(ctx)

return ctx, cancel
}
@@ -169,9 +173,10 @@ func (g *ContextGuard) Create(ctx context.Context,
return ctx, cancel
}

// ctxQuitUnsafe spins off a goroutine that will block until the passed context
// is cancelled or until the quit channel has been signaled after which it will
// call the passed cancel function and decrement the wait group.
// ctxQuitUnsafe increases the wait group counter, waits until the context is
// cancelled and decreases the wait group counter. It stores the passed cancel
// function and returns a wrapped version, which removes the stored one and
// calls it. The Quit method calls all the stored cancel functions.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
@@ -181,35 +186,27 @@ func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
cancel = g.addCancelFnUnsafe(cancel)

g.wg.Add(1)
go func() {
defer cancel()
defer g.wg.Done()

select {
case <-g.quit:

case <-ctx.Done():
}
}()
// We don't have to wait on g.quit here: g.quit can be closed only in
// the Quit method, which also cancels the context we are waiting for.
context.AfterFunc(ctx, func() {
Collaborator

This doesn't really remove any overhead, since AfterFunc() calls propagateCancel(), which creates the same "wait for context cancel" goroutine. But it's much simpler to read, so nice optimization in any case.
Also nice catch that with g.quit now being un-exported, we can skip that code path as well.

Collaborator Author

> propagateCancel() which creates the same "wait for context cancel" goroutine

When the context is derived from cancelCtx (the most common case), propagateCancel avoids starting additional waiting goroutines. The cancelCtx type has an internal mechanism to track its child contexts, allowing it to recursively cancel all children when the parent context is canceled.

I added the test TestContextGuardCountGoroutines to verify that the Create method does not start any waiting goroutines. Without this optimization, the test would initiate 4000 new goroutines.

It's worth noting that a goroutine is indeed started later during context cancellation. However, it immediately calls wg.Done() and completes, so it doesn't exist throughout the context's lifetime. While the number of goroutines remains the same, their total runtime is significantly reduced, improving efficiency.
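
The behaviour described above can be checked outside of ContextGuard with a minimal sketch. It assumes Go 1.21 or later (where context.AfterFunc exists) and uses a hypothetical helper name, goroutinesAddedByAfterFunc, that is not part of this PR. It registers callbacks on a context returned by context.WithCancel and measures how many goroutines exist while the callbacks are pending.

```go
package main

import (
	"context"
	"fmt"
	"runtime"
	"sync"
)

// goroutinesAddedByAfterFunc registers n callbacks on a cancellable context
// and returns how many goroutines appeared while they were waiting.
func goroutinesAddedByAfterFunc(n int) int {
	// WithCancel returns a context backed by the standard library's
	// internal cancelCtx, the case discussed above.
	ctx, cancel := context.WithCancel(context.Background())

	var wg sync.WaitGroup
	before := runtime.NumGoroutine()

	for i := 0; i < n; i++ {
		wg.Add(1)
		// The returned stop function is not needed in this sketch.
		context.AfterFunc(ctx, wg.Done)
	}

	added := runtime.NumGoroutine() - before

	// Cancelling runs each callback in its own short-lived goroutine.
	cancel()
	wg.Wait()

	return added
}

func main() {
	fmt.Println("goroutines added while waiting:", goroutinesAddedByAfterFunc(1000))
}
```

With a parent that is not derived from a standard library cancellable context (for example a custom Context implementation), the standard library may fall back to a waiting goroutine, so the count is not guaranteed to stay flat in that case.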

Collaborator

Ah, right, that makes sense. Should've looked at the code in propagateCancel more closely I guess 🙈 Thanks for the clarification! And the additional unit test.

Contributor

Thanks for this PR! @starius I understand the core change is to use context.AfterFunc instead of a goroutine for context cancellation handling, which should reduce overhead.

> Without this optimization, the test would initiate 4000 new goroutines.

I'm trying to understand what you mean by this.
Are you referring to starting a goroutine for each context created here (in the test)?

// Create 1000 contexts of each type.
	for i := 0; i < 1000; i++ {
		_, _ = g.Create(ctx)
		_, _ = g.Create(ctx, WithBlockingCG())
		_, _ = g.Create(ctx, WithTimeoutCG())
		_, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG())
	}

The added test is great as well. Overall, this looks like a solid improvement.

Collaborator Author

> Are you referring to starting a goroutine for each context created here (in the test)?

Yes.
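
For contrast, here is a rough sketch of the goroutine-per-context pattern that the removed code used, reduced to a standalone helper. The name goroutinesAddedByWaiters and the local quit channel are assumptions for illustration and do not come from the PR. Each registration parks a goroutine in a select until the context is cancelled, so the goroutine count grows with the number of contexts.

```go
package main

import (
	"context"
	"fmt"
	"runtime"
	"sync"
)

// goroutinesAddedByWaiters starts one blocked goroutine per registration and
// returns how many goroutines exist only to wait for cancellation.
func goroutinesAddedByWaiters(n int) int {
	ctx, cancel := context.WithCancel(context.Background())
	quit := make(chan struct{}) // stand-in for a quit channel; never closed here

	var wg sync.WaitGroup
	before := runtime.NumGoroutine()

	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			// Block for the whole lifetime of the context.
			select {
			case <-quit:
			case <-ctx.Done():
			}
		}()
	}

	added := runtime.NumGoroutine() - before

	// Release all waiters and let them finish.
	cancel()
	wg.Wait()

	return added
}

func main() {
	fmt.Println("goroutines parked while waiting:", goroutinesAddedByWaiters(1000))
}
```

In the test from this PR the loop runs 1000 times with four Create calls per iteration, which is where the figure of 4000 goroutines comes from.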

g.wg.Done()
})

return cancel
}

// ctxBlocking spins off a goroutine that will block until the passed context
// is cancelled after which it will call the passed cancel function and
// decrement the wait group.
func (g *ContextGuard) ctxBlocking(ctx context.Context,
cancel context.CancelFunc) {

// ctxBlocking increases the wait group counter, waits until the context is
// cancelled and decreases the wait group counter.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
func (g *ContextGuard) ctxBlocking(ctx context.Context) {
g.wg.Add(1)
go func() {
defer cancel()
defer g.wg.Done()

select {
case <-ctx.Done():
}
}()
context.AfterFunc(ctx, func() {
g.wg.Done()
})
}

// addCancelFnUnsafe adds a context cancel function to the manager and returns a
42 changes: 42 additions & 0 deletions fn/context_guard_test.go
@@ -2,8 +2,11 @@ package fn

import (
"context"
"runtime"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// TestContextGuard tests the behaviour of the ContextGuard.
@@ -298,6 +301,12 @@ func TestContextGuard(t *testing.T) {
case <-time.After(time.Second):
t.Fatalf("timeout")
}

// Cancel the context.
cancel()

// Make sure wg's counter gets to 0 eventually.
g.WgWait()
})

// Test that if we add the CustomTimeoutCGOpt option, then the context
@@ -433,3 +442,36 @@ func TestContextGuard(t *testing.T) {
}
})
}

// TestContextGuardCountGoroutines makes sure that ContextGuard doesn't create
// any goroutines while waiting for contexts.
func TestContextGuardCountGoroutines(t *testing.T) {
// NOTE: t.Parallel() is not called in this test because it relies on an
// accurate count of active goroutines. Running other tests in parallel
// would introduce additional goroutines, leading to unreliable results.

g := NewContextGuard()

ctx, cancel := context.WithCancel(context.Background())

// Count goroutines before contexts are created.
count1 := runtime.NumGoroutine()

// Create 1000 contexts of each type.
for i := 0; i < 1000; i++ {
_, _ = g.Create(ctx)
_, _ = g.Create(ctx, WithBlockingCG())
_, _ = g.Create(ctx, WithTimeoutCG())
_, _ = g.Create(ctx, WithBlockingCG(), WithTimeoutCG())
}

// Make sure no new goroutine was launched.
count2 := runtime.NumGoroutine()
require.LessOrEqual(t, count2, count1)

// Cancel root context.
cancel()

// Make sure wg's counter gets to 0 eventually.
g.WgWait()
}