diff --git a/concurrency/session_manager.go b/concurrency/session_manager.go index d241860..fe1d4c2 100644 --- a/concurrency/session_manager.go +++ b/concurrency/session_manager.go @@ -2,7 +2,6 @@ package concurrency import ( "context" - "fmt" "sync" "time" @@ -15,18 +14,11 @@ const ( ) type SessionManager struct { - // These fields must not be accessed by more than one - // goroutine. - // Session singleton that is refreshed if it closes. - session *Session - // Channel used by the session to communicate that it is closed. - sessionDone <-chan struct{} + // Session singleton that is cleared if it closes. + session *Session + sessionMutex sync.Mutex logger *zap.Logger - retryDelay time.Duration - get chan sessionManagerGetRequest - close chan struct{} - closeOnce sync.Once newSession func() (*Session, error) } @@ -38,12 +30,8 @@ func NewSessionManager(client *clientv3.Client, logger *zap.Logger) *SessionMana func newSessionManager(client *clientv3.Client, retryDelay time.Duration, logger *zap.Logger) *SessionManager { sm := &SessionManager{ logger: logger, - retryDelay: retryDelay, - get: make(chan sessionManagerGetRequest), - close: make(chan struct{}), newSession: func() (*Session, error) { return NewSession(client) }, } - go sm.run() return sm } @@ -51,105 +39,23 @@ func newSessionManager(client *clientv3.Client, retryDelay time.Duration, logger // cannot be obtained. The context needs to have a timeout, otherwise it // is possible for the calling goroutine to hang. func (sm *SessionManager) GetSession(ctx context.Context) (*Session, error) { - request := sessionManagerGetRequest{ - resp: make(chan *Session), - } - go func() { - sm.get <- request - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case session := <-request.resp: - return session, nil - } -} - -// Close closes the manager, causing the current session to be closed -// and no new ones to be created. -func (sm *SessionManager) Close() { - sm.closeOnce.Do(func() { - close(sm.close) - }) -} - -func (sm *SessionManager) resetSession() { - sm.logger.Info("Initializing session") - session, err := sm.newSession() - for err != nil { - sm.logger.Error("Error getting session", zap.Error(err)) - stopRetry := false - func() { - ctx, cancel := context.WithTimeout(context.Background(), sm.retryDelay) - defer cancel() - select { - case <-ctx.Done(): - // Let pass so retry can be attempted. - case <-sm.close: - stopRetry = true - } - }() - if stopRetry { - return - } - session, err = sm.newSession() - } - sm.session = session - sm.sessionDone = session.Done() - sm.logger.Info("new session initialized", zap.String("lease_id", fmt.Sprintf("%x", sm.session.Lease()))) -} - -func (sm *SessionManager) run() { - // Thread safety is handled by controlling all activity - // through a single goroutine that interacts with other - // goroutines via channels. - sm.logger.Info("Starting session manager") -run: - for { - // If the session manager should be closed, give - // that the highest priority. - select { - case <-sm.close: - sm.logger.Info("Closing session manager") - if sm.session != nil { - // This may fail the session was already closed - // due to some external cause, like etcd connectivity - // issues. The result is just a log message. - sm.session.Close() - } - break run - default: - } - switch { - case sm.sessionDone == nil: - sm.resetSession() - continue - } - // If the current session has closed, - // prioritize creating a new one ahead - // of remaining concerns. 
- select { - case <-sm.sessionDone: - // Create new session - sm.resetSession() - continue - default: - } - select { - case <-sm.close: - // Let the check above take care of cleanup - continue - case <-sm.sessionDone: - // Let the check above take care of creating a new session - continue - case req := <-sm.get: - // Get the current session - req.resp <- sm.session + sm.sessionMutex.Lock() + defer sm.sessionMutex.Unlock() + if sm.session == nil { + var err error + sm.session, err = sm.newSession() + if err != nil { + return nil, err } + sessionDone := sm.session.Done() + // Start goroutine to check for closed session. + go func() { + <-sessionDone + // Clear out dead session + sm.sessionMutex.Lock() + defer sm.sessionMutex.Unlock() + sm.session = nil + }() } - sm.logger.Info("Session manager closed") -} - -type sessionManagerGetRequest struct { - resp chan *Session + return sm.session, nil } diff --git a/concurrency/session_manager_test.go b/concurrency/session_manager_test.go index 5582a22..5790a91 100644 --- a/concurrency/session_manager_test.go +++ b/concurrency/session_manager_test.go @@ -2,14 +2,11 @@ package concurrency import ( "context" - "errors" "sync" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "go.etcd.io/etcd/clientv3" "go.uber.org/zap" "github.com/IBM-Cloud/go-etcd-rules/rules/teststore" @@ -17,6 +14,7 @@ import ( func Test_SessionManager(t *testing.T) { _, client := teststore.InitV3Etcd(t) + defer client.Close() lgr, err := zap.NewDevelopment() require.NoError(t, err) mgr := newSessionManager(client, 0, lgr) @@ -41,84 +39,15 @@ func Test_SessionManager(t *testing.T) { wg.Wait() } -func Test_SessionManager_Close(t *testing.T) { - lgr, err := zap.NewDevelopment() - require.NoError(t, err) - _, goodClient := teststore.InitV3Etcd(t) - badClient, _ := clientv3.New(clientv3.Config{ - Endpoints: []string{"http://127.0.0.1:2377"}, - }) - testCases := []struct { - name string - - client *clientv3.Client - newSession func() (*Session, error) - }{ - { - name: "ok", - client: goodClient, - newSession: func() (*Session, error) { - return NewSession(goodClient) - }, - }, - { - name: "bad", - client: badClient, - newSession: func() (*Session, error) { - return nil, errors.New("bad") - }, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - mgr := &SessionManager{ - logger: lgr, - retryDelay: time.Millisecond, - get: make(chan sessionManagerGetRequest), - close: make(chan struct{}), - newSession: tc.newSession, - } - go mgr.run() - var wg sync.WaitGroup - // Use a lot of goroutines to ensure any concurrency - // issues are caught by race condition checks. - for i := 0; i < 1000; i++ { - // Make a copy for the goroutine - localI := i - wg.Add(1) - go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - session, _ := mgr.GetSession(ctx) - if localI%10 == 0 { - // Disrupt things by closing sessions, forcing - // the manager to create new ones. 
- if session != nil { - _ = session.Close() - } - } - if localI%25 == 0 { - mgr.Close() - } - wg.Done() - }() - } - wg.Wait() - }) - } -} - func Test_NewSessionManager(t *testing.T) { _, client := teststore.InitV3Etcd(t) lgr, err := zap.NewDevelopment() require.NoError(t, err) mgr := NewSessionManager(client, lgr) assert.Equal(t, lgr, mgr.logger) - assert.Equal(t, sessionManagerRetryDelay, mgr.retryDelay) - assert.NotNil(t, mgr.get) - assert.NotNil(t, mgr.close) session, err := mgr.newSession() - require.NoError(t, err) - session.Close() - mgr.Close() + if assert.NoError(t, err) { + session.Close() + } + assert.NoError(t, client.Close()) } diff --git a/rules/engine.go b/rules/engine.go index 3294c83..65bfd82 100644 --- a/rules/engine.go +++ b/rules/engine.go @@ -10,6 +10,7 @@ import ( "go.uber.org/zap" "golang.org/x/net/context" + "github.com/IBM-Cloud/go-etcd-rules/concurrency" "github.com/IBM-Cloud/go-etcd-rules/rules/lock" ) @@ -112,8 +113,30 @@ func newV3Engine(logger *zap.Logger, cl *clientv3.Client, options ...EngineOptio logger: logger, } } - baseEtcdLocker := lock.NewV3Locker(cl, opts.lockAcquisitionTimeout) + var baseEtcdLocker lock.RuleLocker + if opts.useSharedLockSession { + sessionManager := concurrency.NewSessionManager(cl, logger) + baseEtcdLocker = lock.NewSessionLocker(sessionManager.GetSession, opts.lockAcquisitionTimeout, false, opts.useTryLock) + } else { + baseEtcdLocker = lock.NewV3Locker(cl, opts.lockAcquisitionTimeout, opts.useTryLock) + } metricsEtcdLocker := lock.WithMetrics(baseEtcdLocker, "etcd") + var baseLocker lock.RuleLocker + if opts.useSharedLockSession { + baseMapLocker := lock.NewMapLocker() + metricsMapLocker := lock.WithMetrics(baseMapLocker, "map") + baseLocker = lock.NewNestedLocker(metricsMapLocker, metricsEtcdLocker) + } else { + baseLocker = metricsEtcdLocker + } + var finalLocker lock.RuleLocker + if opts.lockCoolOff == 0 { + finalLocker = baseLocker + } else { + coolOffLocker := lock.NewCoolOffLocker(opts.lockCoolOff) + metricsCoolOffLocker := lock.WithMetrics(coolOffLocker, "cooloff") + finalLocker = lock.NewNestedLocker(metricsCoolOffLocker, baseLocker) + } eng := v3Engine{ baseEngine: baseEngine{ keyProc: &keyProc, @@ -122,7 +145,7 @@ func newV3Engine(logger *zap.Logger, cl *clientv3.Client, options ...EngineOptio options: opts, ruleLockTTLs: map[int]int{}, ruleMgr: ruleMgr, - locker: metricsEtcdLocker, + locker: finalLocker, callbackListener: cbListener, }, keyProc: keyProc, diff --git a/rules/lock/cooloff_lock.go b/rules/lock/cooloff_lock.go new file mode 100644 index 0000000..5fbe2b1 --- /dev/null +++ b/rules/lock/cooloff_lock.go @@ -0,0 +1,69 @@ +package lock + +import ( + "fmt" + "sync" + "time" +) + +const ( + coolOffErrFormat = "cooloff expires in %s" +) + +// NewCoolOffLocker creates a simple locker that will prevent a lock from +// being obtained if a previous attempt (successful or not) was made within +// the specified expiration period. It is intended to be used with other lockers +// to prevent excessive locking using more expensive resources (e.g. etcd). It is +// theoretically possible for two callers to obtain the same lock, if the cooloff +// period expires before the first caller releases the lock; therefore this locker +// needs to be used with a nested locker to prevent two callers from accessing the +// same protected resource. 
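// Usage sketch (illustrative only; it mirrors the wiring added to rules/engine.go
// in this change, and the 500ms cool-off, 5s lock timeout, try-lock flag and key
// below are example values, not defaults):
//
//	etcdLocker := lock.WithMetrics(lock.NewV3Locker(cl, 5, true), "etcd")
//	coolOff := lock.WithMetrics(lock.NewCoolOffLocker(500*time.Millisecond), "cooloff")
//	// The cool-off locker only rate-limits lock attempts; the nested etcd
//	// locker still provides the actual mutual exclusion.
//	locker := lock.NewNestedLocker(coolOff, etcdLocker)
//	lck, err := locker.Lock("/rule/key")
//	if err == nil {
//		defer func() { _ = lck.Unlock() }()
//	}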
+func NewCoolOffLocker(expiration time.Duration) RuleLocker { + locker := coolOffLocker{ + coolOffDuration: expiration, + locks: make(map[string]time.Time), + mutex: &sync.Mutex{}, + } + return locker +} + +type coolOffLocker struct { + locks map[string]time.Time + mutex *sync.Mutex + coolOffDuration time.Duration +} + +func (col coolOffLocker) Lock(key string, options ...Option) (RuleLock, error) { + col.mutex.Lock() + defer col.mutex.Unlock() + now := time.Now() + // Remove any expired keys + var toDelete []string + for k, v := range col.locks { + if now.After(v) { + toDelete = append(toDelete, k) + } + } + for _, key := range toDelete { + delete(col.locks, key) + } + var err error + if _, ok := col.locks[key]; ok { + err = fmt.Errorf(coolOffErrFormat, col.coolOffDuration) + } + // Failed attempts to get the lock should also update the cooloff, + // so always add the key regardless of success or failure. + col.locks[key] = now.Add(col.coolOffDuration) + + if err != nil { + return nil, err + } + return coolOffLock{}, nil +} + +type coolOffLock struct { +} + +func (coolOffLock) Unlock() error { + return nil +} diff --git a/rules/lock/cooloff_lock_test.go b/rules/lock/cooloff_lock_test.go new file mode 100644 index 0000000..35330e7 --- /dev/null +++ b/rules/lock/cooloff_lock_test.go @@ -0,0 +1,94 @@ +package lock + +import ( + "fmt" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_coolOffLocker(t *testing.T) { + const ( + timeout = time.Millisecond * 100 + key1 = "key1" + key2 = "key2" + ) + // Test cases will try to obtain two locks. + // The duration between requests is controlled + // via the "delay" field and whether or not the + // second attempt uses the same key is controlled + // via the "keyDifferent" field. + testCases := []struct { + name string + + delay time.Duration + keyDifferent bool + + err bool + }{ + { + name: "ok_same_key_enough_delay", + delay: timeout * 2, + keyDifferent: false, + }, + { + name: "ok_different_key_no_delay", + delay: 0, + keyDifferent: true, + }, + { + name: "fail_same_key_insufficient_delay", + delay: 0, + keyDifferent: false, + err: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var err error + locker := NewCoolOffLocker(timeout) + // Obtain the first lock + lock1, err := locker.Lock(key1) + // This should always be successful + require.NoError(t, err) + require.NotNil(t, lock1) + // Make sure that the lock works + defer require.NotPanics(t, func() { assert.NoError(t, lock1.Unlock()) }) + + // Wait some period of time + time.Sleep(tc.delay) + + var secondLockKey string + if tc.keyDifferent { + secondLockKey = key2 + } else { + secondLockKey = key1 + } + lock2, err := locker.Lock(secondLockKey) + + if tc.err { + if assert.Error(t, err) { + // Verify that the error message uses the correct format. 
+ var ( + timeoutString string + valueCount int + ) + valueCount, err = fmt.Fscanf(strings.NewReader(err.Error()), coolOffErrFormat, &timeoutString) + assert.NoError(t, err) + assert.Equal(t, 1, valueCount) + assert.NotEmpty(t, timeoutString) + } + assert.Nil(t, lock2) + return + } + assert.NoError(t, err) + assert.NotNil(t, lock2) + // Make sure the second lock works + require.NotPanics(t, func() { assert.NoError(t, lock2.Unlock()) }) + }) + } + +} diff --git a/rules/lock/lock.go b/rules/lock/lock.go index 0526093..476bb41 100644 --- a/rules/lock/lock.go +++ b/rules/lock/lock.go @@ -18,44 +18,67 @@ type RuleLock interface { Unlock() error } +type GetSession func(context.Context) (*concurrency.Session, error) + // NewV3Locker creates a locker backed by etcd V3. -func NewV3Locker(cl *clientv3.Client, lockTimeout int) RuleLocker { +func NewV3Locker(cl *clientv3.Client, lockTimeout int, useTryLock bool) RuleLocker { + // The TTL is for the lease associated with the session, in seconds. While the session is still open, + // the lease's TTL will keep getting renewed to keep it from expiring, so all this really does is + // set the amount of time it takes for the lease to expire if the lease stops being renewed due + // to the application shutting down before a session could be properly closed. + newSession := func(_ context.Context) (*concurrency.Session, error) { + return concurrency.NewSession(cl, concurrency.WithTTL(30)) + } + return NewSessionLocker(newSession, lockTimeout, true, useTryLock) +} + +// NewSessionLocker creates a new locker with the provided session constructor. Note that +// if closeSession is false, it means that the session provided by getSession will not be +// closed but instead be reused. In that case the locker must be protected by another locker +// (for instance an in-memory locker) because locks within the same session are reentrant +// which means that two goroutines can obtain the same lock. +func NewSessionLocker(getSession GetSession, lockTimeout int, closeSession, useTryLock bool) RuleLocker { return &v3Locker{ - cl: cl, - lockTimeout: lockTimeout, + lockTimeout: lockTimeout, + newSession: getSession, + closeSession: closeSession, + useTryLock: useTryLock, } } type v3Locker struct { - cl *clientv3.Client - lockTimeout int + lockTimeout int + newSession GetSession + closeSession bool + useTryLock bool } func (v3l *v3Locker) Lock(key string, options ...Option) (RuleLock, error) { return v3l.lockWithTimeout(key, v3l.lockTimeout) } func (v3l *v3Locker) lockWithTimeout(key string, timeout int) (RuleLock, error) { - // TODO once we switch to a shared session, we can get rid of the TTL option - // and go to the default (60 seconds). This is the TTL for the lease that - // is associated with the session and the lease is renewed before it expires - // while the session is active (not closed). It is not the TTL of any locks; - // those persist until Unlock is called or the process dies and the session - // lease is allowed to expire. 
- s, err := concurrency.NewSession(v3l.cl, concurrency.WithTTL(30)) + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + s, err := v3l.newSession(ctx) if err != nil { return nil, err } m := concurrency.NewMutex(s, key) - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) - defer cancel() - err = m.Lock(ctx) + if v3l.useTryLock { + err = m.TryLock(ctx) + } else { + err = m.Lock(ctx) + } if err != nil { return nil, err } - return &v3Lock{ - mutex: m, - session: s, - }, nil + lock := &v3Lock{ + mutex: m, + } + if v3l.closeSession { + lock.session = s + } + return lock, nil } type v3Lock struct { diff --git a/rules/lock/lock_test.go b/rules/lock/lock_test.go index 156cbc2..7dd46f5 100644 --- a/rules/lock/lock_test.go +++ b/rules/lock/lock_test.go @@ -4,40 +4,57 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.etcd.io/etcd/clientv3" + "golang.org/x/net/context" + "github.com/IBM-Cloud/go-etcd-rules/concurrency" "github.com/IBM-Cloud/go-etcd-rules/rules/teststore" ) func Test_V3Locker(t *testing.T) { cfg, cl := teststore.InitV3Etcd(t) c, err := clientv3.New(cfg) - assert.NoError(t, err) - rlckr := v3Locker{ - cl: cl, - lockTimeout: 5, + require.NoError(t, err) + newSession := func(_ context.Context) (*concurrency.Session, error) { + return concurrency.NewSession(cl, concurrency.WithTTL(30)) } - rlck, err1 := rlckr.Lock("test") - assert.NoError(t, err1) - _, err2 := rlckr.lockWithTimeout("test", 10) - assert.Error(t, err2) - assert.NoError(t, rlck.Unlock()) - done1 := make(chan bool) - done2 := make(chan bool) - - go func() { - lckr := NewV3Locker(c, 5) - lck, lErr := lckr.Lock("test1") - assert.NoError(t, lErr) - done1 <- true - <-done2 - if lck != nil { - assert.NoError(t, lck.Unlock()) + for _, useTryLock := range []bool{false, true} { + var name string + if useTryLock { + name = "use_try_lock" + } else { + name = "use_lock" } - }() - <-done1 - _, err = rlckr.Lock("test1") - assert.Error(t, err) - done2 <- true + t.Run(name, func(t *testing.T) { + rlckr := v3Locker{ + newSession: newSession, + lockTimeout: 5, + } + rlck, err1 := rlckr.Lock("test") + assert.NoError(t, err1) + _, err2 := rlckr.lockWithTimeout("test", 10) + assert.Error(t, err2) + assert.NoError(t, rlck.Unlock()) + + done1 := make(chan bool) + done2 := make(chan bool) + + go func() { + lckr := NewV3Locker(c, 5, useTryLock) + lck, lErr := lckr.Lock("test1") + assert.NoError(t, lErr) + done1 <- true + <-done2 + if lck != nil { + assert.NoError(t, lck.Unlock()) + } + }() + <-done1 + _, err = rlckr.Lock("test1") + assert.Error(t, err) + done2 <- true + }) + } } diff --git a/rules/lock/map_locker.go b/rules/lock/map_locker.go new file mode 100644 index 0000000..4b691f0 --- /dev/null +++ b/rules/lock/map_locker.go @@ -0,0 +1,78 @@ +package lock + +import ( + "errors" + "sync" +) + +func NewMapLocker() RuleLocker { + ml := newMapLocker() + // Using the adapter to reduce the number of critical sections down to + // 1, lessening the chances of concurrency issues being introduced. + return toggleLockerAdapter{ + toggle: ml.toggle, + errLocked: ErrLockedLocally, + } +} + +type mapLocker struct { + mutex *sync.Mutex + m map[string]bool +} + +func newMapLocker() mapLocker { + return mapLocker{ + m: make(map[string]bool), + mutex: &sync.Mutex{}, + } +} + +func (ml mapLocker) toggle(key string, lock bool) bool { + ml.mutex.Lock() + defer ml.mutex.Unlock() + // 4 possibilities: + // 1. 
key is locked and lock is true: return false + // 2. key is locked and lock is false: unlock key and return true + if ml.m[key] { + if !lock { + delete(ml.m, key) + } + return !lock + } + // 3. key is unlocked and lock is true: lock key and return true + // 4. key is unlocked and lock is false: return true + if lock { + ml.m[key] = true + } + return true +} + +// ErrLockedLocally indicates that a local goroutine holds the lock +// and no attempt will be made to obtain the lock via etcd. +var ErrLockedLocally = errors.New("locked locally") + +type toggleLockerAdapter struct { + toggle func(key string, lock bool) bool + errLocked error +} + +func (tla toggleLockerAdapter) Lock(key string, options ...Option) (RuleLock, error) { + ok := tla.toggle(key, true) + if !ok { + return nil, tla.errLocked + } + return toggleLock{ + toggle: tla.toggle, + key: key, + }, nil +} + +type toggleLock struct { + toggle func(key string, lock bool) bool + key string +} + +func (tl toggleLock) Unlock() error { + _ = tl.toggle(tl.key, false) + return nil +} diff --git a/rules/lock/map_locker_test.go b/rules/lock/map_locker_test.go new file mode 100644 index 0000000..86b229c --- /dev/null +++ b/rules/lock/map_locker_test.go @@ -0,0 +1,116 @@ +package lock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_mapLocker_toggle(t *testing.T) { + testCases := []struct { + name string + + setup func(ml *mapLocker) + + key string + lock bool + + ok bool + }{ + { + name: "get_available", + key: "/foo", + setup: func(ml *mapLocker) { + ml.toggle("/bar", true) + }, + lock: true, + ok: true, + }, + { + name: "get_unavailable", + key: "/foo", + setup: func(ml *mapLocker) { + ml.toggle("/foo", true) + }, + lock: true, + ok: false, + }, + { + name: "release_existing", + key: "/foo", + setup: func(ml *mapLocker) { + ml.toggle("/foo", true) + }, + lock: false, + ok: true, + }, + { + name: "release_nonexistent", + key: "/foo", + lock: false, + ok: true, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ml := newMapLocker() + // defer ml.close() + + if tc.setup != nil { + tc.setup(&ml) + } + + assert.Equal(t, tc.ok, ml.toggle(tc.key, tc.lock)) + + }) + } +} + +func Test_toggleLockAdapter(t *testing.T) { + const ( + testKey = "/foo" + ) + errLocked := errors.New("locked") + testCases := []struct { + name string + + lock bool + toggleOk bool + + err error + }{ + { + name: "success", + toggleOk: true, + }, + { + name: "failure", + toggleOk: false, + err: errLocked, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + expectedLock := true + var err error + tla := toggleLockerAdapter{ + toggle: func(key string, lock bool) bool { + assert.Equal(t, expectedLock, lock) + assert.Equal(t, testKey, key) + return tc.toggleOk + }, + errLocked: errLocked, + } + var _ RuleLocker = tla + lock, err := tla.Lock(testKey) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + return + } + assert.NoError(t, err) + expectedLock = false + _ = assert.NotNil(t, lock) && assert.NoError(t, lock.Unlock()) + }) + } +} diff --git a/rules/lock/nested_lock.go b/rules/lock/nested_lock.go new file mode 100644 index 0000000..430f70a --- /dev/null +++ b/rules/lock/nested_lock.go @@ -0,0 +1,61 @@ +package lock + +// NewNestedLocker creates a locker that protects the inner +// locker with an outer locker, so that no unnecessary calls +// are made to the inner locker when attempting to obtain +// an unavailable lock. 
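// Usage sketch (illustrative only; it mirrors the shared-session wiring added to
// rules/engine.go in this change, and the 5s timeout and flag values are examples):
// the cheap in-memory map locker is the outer locker, so a lock that is already
// held locally never triggers a round trip to etcd.
//
//	sessionManager := concurrency.NewSessionManager(client, logger)
//	etcdLocker := lock.WithMetrics(lock.NewSessionLocker(sessionManager.GetSession, 5, false, true), "etcd")
//	mapLocker := lock.WithMetrics(lock.NewMapLocker(), "map")
//	locker := lock.NewNestedLocker(mapLocker, etcdLocker)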
+func NewNestedLocker(outer, inner RuleLocker) RuleLocker { + return nestedLocker{ + own: outer, + nested: inner, + } +} + +type nestedLocker struct { + own RuleLocker + nested RuleLocker +} + +func (nl nestedLocker) Lock(key string, options ...Option) (RuleLock, error) { + // Try to obtain own lock first, preempting attempts + // to obtain the nested (more expensive) lock if + // getting the local lock fails. + lock, err := nl.own.Lock(key, options...) + if err != nil { + return nil, err + } + // Try to obtain the nested lock + nested, err := nl.nested.Lock(key, options...) + if err != nil { + // First unlock own lock + _ = lock.Unlock() + return nil, err + } + return nestedLock{ + own: lock, + nested: nested, + }, nil +} + +type nestedLock struct { + own RuleLock + nested RuleLock +} + +func (nl nestedLock) Unlock() error { + // Always unlock own lock, but after + // nested lock. This prevents attempting + // to get a new instance of the nested lock + // before the own lock is cleared. If the nested + // lock persists due to an error, it should be + // cleared with separate logic. + + err := nl.nested.Unlock() + ownError := nl.own.Unlock() + // The nested lock is assumed to be more expensive so + // its error takes precedence. + if err == nil { + err = ownError + } + return err +} diff --git a/rules/lock/nested_lock_test.go b/rules/lock/nested_lock_test.go new file mode 100644 index 0000000..b5588a5 --- /dev/null +++ b/rules/lock/nested_lock_test.go @@ -0,0 +1,113 @@ +package lock + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_nestedLocker_Lock(t *testing.T) { + // Set up mock data for mock functions + type testLock struct { + RuleLock + val string // Just something to compare. + } + var ownUnlockCalled bool + testOwnLock := testLock{ + RuleLock: FuncMockLock{ + UnlockF: func() error { + ownUnlockCalled = true + return nil + }, + }, + val: "own", + } + testNestedLock := testLock{ + val: "nested", + } + + ownLockErr := errors.New("own lock") + nestedLockErr := errors.New("nested lock") + + testCases := []struct { + name string + + nestedCalled bool + ownUnlockCalled bool + + err error + ownLockErr error + nestedLockErr error + }{ + { + name: "ok", + nestedCalled: true, + }, + { + name: "own_error", + ownLockErr: ownLockErr, + err: ownLockErr, + }, + { + name: "nested_error", + nestedCalled: true, + ownUnlockCalled: true, + nestedLockErr: nestedLockErr, + err: nestedLockErr, + }, + { + name: "both_errors", + ownLockErr: ownLockErr, + nestedLockErr: nestedLockErr, + err: ownLockErr, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Reset from any previous runs + ownUnlockCalled = false + ownCalled := false + nestedCalled := false + nl := nestedLocker{ + own: FuncMockLocker{ + LockF: func(key string, options ...Option) (RuleLock, error) { + assert.Equal(t, "key", key) + ownCalled = true + return testOwnLock, tc.ownLockErr + }, + }, + nested: FuncMockLocker{ + LockF: func(key string, options ...Option) (RuleLock, error) { + // The own locker should have been called first + assert.True(t, ownCalled) + assert.Equal(t, "key", key) + nestedCalled = true + return testNestedLock, tc.nestedLockErr + }, + }, + } + var err error + lock, err := nl.Lock("key") + assert.Equal(t, tc.nestedCalled, nestedCalled) + assert.Equal(t, tc.ownUnlockCalled, ownUnlockCalled) + if tc.err != nil { + assert.EqualError(t, err, tc.err.Error()) + return + } + assert.NoError(t, err) + nLock, ok := lock.(nestedLock) + if assert.True(t, ok) { + 
getVal := func(rl RuleLock) string { + tl, ok := rl.(testLock) + if !ok { + return "" + } + return tl.val + } + assert.Equal(t, testOwnLock.val, getVal(nLock.own)) + assert.Equal(t, testNestedLock.val, getVal(nLock.nested)) + } + }) + } +} diff --git a/rules/options.go b/rules/options.go index 1797d2d..229b5fd 100644 --- a/rules/options.go +++ b/rules/options.go @@ -63,6 +63,9 @@ type engineOptions struct { ruleWorkBuffer int enhancedRuleFilter bool metrics MetricsCollectorOpt + lockCoolOff time.Duration + useSharedLockSession bool + useTryLock bool } func makeEngineOptions(options ...EngineOption) engineOptions { @@ -181,6 +184,33 @@ func KeyConstraint(attribute string, prefix string, chars [][]rune) EngineOption }) } +// EngineUseTryLock is an experimental option to fail locking immediately when a lock +// is already held as opposed to trying to obtain the lock until the timeout expires +func EngineUseTryLock() EngineOption { + return engineOptionFunction(func(o *engineOptions) { + o.useTryLock = true + }) +} + +// EngineUseSharedLockSession is an experimental option to use a single concurrency +// session for managing locks to reduce the ETCD load by eliminating the need to +// create new concurrency session for each locking attempt. +func EngineUseSharedLockSession() EngineOption { + return engineOptionFunction(func(o *engineOptions) { + o.useSharedLockSession = true + }) +} + +// EngineLockCoolOff is an experimental option to preemptively fail locking attempts +// if an attempt to obtain the same lock was made within the specified duration so +// that multiple workers reacting to multiple elements of the same rule and attributes +// do not cause needless locking. +func EngineLockCoolOff(timeout time.Duration) EngineOption { + return engineOptionFunction(func(o *engineOptions) { + o.lockCoolOff = timeout + }) +} + // EngineSyncInterval enables the interval between sync or crawler runs to be configured. // The interval is in seconds. func EngineSyncInterval(interval int) EngineOption {