diff --git a/pkg/notification/manager.go b/pkg/notification/manager.go index b95fef810..5b4aa6a14 100644 --- a/pkg/notification/manager.go +++ b/pkg/notification/manager.go @@ -1,5 +1,5 @@ /**************************************************************************** - * Copyright 2019, Optimizely, Inc. and contributors * + * Copyright 2019-2020, Optimizely, Inc. and contributors * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * @@ -19,6 +19,7 @@ package notification import ( "fmt" + "sync" "sync/atomic" "github.com/optimizely/go-sdk/pkg/logging" @@ -37,6 +38,7 @@ type Manager interface { type AtomicManager struct { handlers map[uint32]func(interface{}) counter uint32 + lock sync.RWMutex } // NewAtomicManager creates a new instance of the atomic manager @@ -48,6 +50,9 @@ func NewAtomicManager() *AtomicManager { // Add adds the given handler func (am *AtomicManager) Add(newHandler func(interface{})) (int, error) { + am.lock.Lock() + defer am.lock.Unlock() + atomic.AddUint32(&am.counter, 1) am.handlers[am.counter] = newHandler return int(am.counter), nil @@ -55,6 +60,9 @@ func (am *AtomicManager) Add(newHandler func(interface{})) (int, error) { // Remove removes handler with the given id func (am *AtomicManager) Remove(id int) { + am.lock.Lock() + defer am.lock.Unlock() + handlerID := uint32(id) if _, ok := am.handlers[handlerID]; ok { delete(am.handlers, handlerID) @@ -66,7 +74,19 @@ func (am *AtomicManager) Remove(id int) { // Send sends the notification to the registered handlers func (am *AtomicManager) Send(notification interface{}) { - for _, handler := range am.handlers { + // copying handler to avoid race condition + handlers := am.copyHandlers() + for _, handler := range handlers { handler(notification) } } + +// Return a copy of the given handlers +func (am *AtomicManager) copyHandlers() (handlers []func(interface{})) { + am.lock.RLock() + defer am.lock.RUnlock() + for _, v := range am.handlers { + handlers = append(handlers, v) + } + return handlers +} diff --git a/pkg/notification/manager_test.go b/pkg/notification/manager_test.go index 23a95f0dc..459e9b1c7 100644 --- a/pkg/notification/manager_test.go +++ b/pkg/notification/manager_test.go @@ -50,3 +50,56 @@ func TestAtomicManager(t *testing.T) { // Sanity check by calling remove with a incorrect handler id atomicManager.Remove(55) } + +func TestSendRaceCondition(t *testing.T) { + sync := make(chan interface{}) + payload := map[string]interface{}{ + "key": "test", + } + atomicManager := NewAtomicManager() + result1, result2 := 0, 0 + listenerCalled := false + + listener1 := func(interface{}) { + } + + listener2 := func(interface{}) { + // Add listener2 internally to assert deadlock + result2, _ = atomicManager.Add(listener1) + // Remove all added listeners + atomicManager.Remove(result1) + atomicManager.Remove(result2) + listenerCalled = true + } + result1, _ = atomicManager.Add(listener2) + + go func() { + atomicManager.Send(payload) + // notifying that notification is sent. + sync <- "" + }() + + atomicManager.Add(listener1) + <-sync + + assert.Equal(t, 1, result1) + assert.Equal(t, len(atomicManager.handlers), 1) + assert.Equal(t, true, listenerCalled) +} + +func TestAddRaceCondition(t *testing.T) { + sync := make(chan interface{}) + atomicManager := NewAtomicManager() + + listener1 := func(interface{}) { + + } + result1, _ := atomicManager.Add(listener1) + go func() { + atomicManager.Remove(result1) + sync <- "" + }() + + <-sync + assert.Equal(t, len(atomicManager.handlers), 0) +}