diff --git a/context/pool.go b/context/pool.go new file mode 100644 index 0000000..3340706 --- /dev/null +++ b/context/pool.go @@ -0,0 +1,98 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package context + +import ( + "context" + "sync" +) + +// Pool is a pool of contexts whereby the callee context is only cancelled when +// all callers in the pool are done. Contexts added after the pool is created +// are also tracked, but will be ignored if the pool is already cancelled. +type Pool struct { + context.Context + closed chan struct{} + pool []<-chan struct{} + lock sync.RWMutex +} + +// NewPool creates a new context pool with the given contexts. The returned +// context is cancelled when all contexts in the pool are done. Added contexts +// are ignored if the pool is already cancelled or if all current contexts in +// the pool are done. +func NewPool(ctx ...context.Context) *Pool { + callee, cancel := context.WithCancel(context.Background()) + p := &Pool{ + Context: callee, + pool: make([]<-chan struct{}, 0, len(ctx)), + closed: make(chan struct{}), + } + + for i := range ctx { + select { + case <-ctx[i].Done(): + default: + p.pool = append(p.pool, ctx[i].Done()) + } + } + + p.lock.RLock() + go func() { + defer cancel() + defer p.lock.RUnlock() + for i := 0; i < len(p.pool); i++ { + ch := p.pool[i] + p.lock.RUnlock() + select { + case <-ch: + case <-p.closed: + } + p.lock.RLock() + } + }() + + return p +} + +// Add adds a context to the pool. The context is ignored if the pool is +// already cancelled or if all current contexts in the pool are done. +func (p *Pool) Add(ctx context.Context) *Pool { + p.lock.Lock() + defer p.lock.Unlock() + select { + case <-p.Done(): + case <-p.closed: + default: + p.pool = append(p.pool, ctx.Done()) + } + return p +} + +// Cancel cancels the pool. Removes all contexts from the pool. +func (p *Pool) Cancel() { + p.lock.Lock() + defer p.lock.Unlock() + if p.pool != nil { + close(p.closed) + p.pool = nil + } +} + +// Size returns the number of contexts in the pool. +func (p *Pool) Size() int { + p.lock.RLock() + defer p.lock.RUnlock() + return len(p.pool) +} diff --git a/context/pool_test.go b/context/pool_test.go new file mode 100644 index 0000000..6d25ee3 --- /dev/null +++ b/context/pool_test.go @@ -0,0 +1,134 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package context + +import ( + "context" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_Pool(t *testing.T) { + var _ context.Context = &Pool{} + + t.Run("a pool with no context will always be done", func(t *testing.T) { + t.Parallel() + pool := NewPool() + select { + case <-pool.Done(): + case <-time.After(time.Second): + t.Error("expected context pool to be cancelled") + } + }) + + t.Run("a cancelled context given to pool, should have pool cancelled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + pool := NewPool(ctx) + select { + case <-pool.Done(): + case <-time.After(time.Second): + t.Error("expected context pool to be cancelled") + } + }) + + t.Run("a cancelled context given to pool, given a new context, should still have pool cancelled", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + pool := NewPool(ctx) + pool.Add(context.Background()) + select { + case <-pool.Done(): + case <-time.After(time.Second): + t.Error("expected context pool to be cancelled") + } + }) + + t.Run("pool with multiple contexts should return once all contexts have been cancelled", func(t *testing.T) { + t.Parallel() + var ctx [50]context.Context + var cancel [50]context.CancelFunc + + ctx[0], cancel[0] = context.WithCancel(context.Background()) + pool := NewPool(ctx[0]) + + for i := 1; i < 50; i++ { + ctx[i], cancel[i] = context.WithCancel(context.Background()) + pool.Add(ctx[i]) + } + + //nolint:gosec + r := rand.New(rand.NewSource(time.Now().UnixNano())) + r.Shuffle(len(ctx), func(i, j int) { + ctx[i], ctx[j] = ctx[j], ctx[i] + cancel[i], cancel[j] = cancel[j], cancel[i] + }) + + for i := 0; i < 50; i++ { + select { + case <-pool.Done(): + t.Error("expected context to not be cancelled") + case <-time.After(time.Millisecond): + } + cancel[i]() + } + + select { + case <-pool.Done(): + case <-time.After(time.Second): + t.Error("expected context pool to be cancelled") + } + }) + + t.Run("pool size will not increase if the given contexts have been cancelled", func(t *testing.T) { + t.Parallel() + + ctx1, cancel1 := context.WithCancel(context.Background()) + ctx2, cancel2 := context.WithCancel(context.Background()) + pool := NewPool(ctx1, ctx2) + assert.Equal(t, 2, pool.Size()) + + cancel1() + cancel2() + select { + case <-pool.Done(): + case <-time.After(time.Second): + t.Error("expected context pool to be cancelled") + } + pool.Add(context.Background()) + assert.Equal(t, 2, pool.Size()) + }) + + t.Run("pool size will not increase if the pool has been closed", func(t *testing.T) { + t.Parallel() + + ctx1 := context.Background() + ctx2 := context.Background() + pool := NewPool(ctx1, ctx2) + assert.Equal(t, 2, pool.Size()) + pool.Cancel() + pool.Add(context.Background()) + assert.Equal(t, 0, pool.Size()) + select { + case <-pool.Done(): + case <-time.After(time.Second): + t.Error("expected context pool to be cancelled") + } + }) +}