Use atomic.Value for rate limiter and fix tests
madhuravi committed Dec 20, 2017
1 parent b607c75 commit 162407f
Showing 2 changed files with 36 additions and 29 deletions.
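
The core of the change in taskListManager.go is that the rate limiter now keeps its *rate.Limiter (golang.org/x/time/rate) inside an atomic.Value: Wait and Reserve simply Load the current limiter instead of taking the RWMutex, and UpdateMaxDispatch Stores a freshly built limiter. Below is a minimal, self-contained sketch of that pattern; the names swappableLimiter and SetRate are illustrative and are not the identifiers used in this commit.

package main

import (
	"context"
	"fmt"
	"sync/atomic"

	"golang.org/x/time/rate"
)

// swappableLimiter holds the current *rate.Limiter in an atomic.Value so that
// readers on the hot path never block on a lock while the limit is updated.
type swappableLimiter struct {
	limiter atomic.Value // always stores a *rate.Limiter
}

func newSwappableLimiter(rps float64) *swappableLimiter {
	s := &swappableLimiter{}
	s.limiter.Store(rate.NewLimiter(rate.Limit(rps), int(rps)))
	return s
}

// SetRate swaps in a new limiter; callers already blocked in Wait keep using the old one.
func (s *swappableLimiter) SetRate(rps float64) {
	s.limiter.Store(rate.NewLimiter(rate.Limit(rps), int(rps)))
}

// Wait blocks until whichever limiter is current grants a token, or ctx ends.
func (s *swappableLimiter) Wait(ctx context.Context) error {
	return s.limiter.Load().(*rate.Limiter).Wait(ctx)
}

func main() {
	l := newSwappableLimiter(100)
	if err := l.Wait(context.Background()); err != nil {
		fmt.Println("throttled:", err)
	}
	l.SetRate(10) // hot-swap the limit without blocking readers
}

One caveat of the pattern: each Store replaces the limiter wholesale, so any tokens accumulated by the old limiter are discarded. That is acceptable here because updates are gated by the TTL check in shouldUpdate.
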
service/matching/matchingEngine_test.go (3 changes: 1 addition & 2 deletions)
@@ -600,7 +600,6 @@ func (s *matchingEngineSuite) TestConcurrentPublishConsumeActivitiesWithZeroDisp
	errCt := s.concurrentPublishConsumeActivities(workerCount, taskCount, dispatchLimitFn)
	// atleast 4 times from 0 dispatch poll, but quite a bit more until TTL is hit and throttle limit
	// is reset
-	fmt.Println("Error count: ", errCt)
	s.True(errCt >= 4 && errCt < (workerCount*int(taskCount)))
}

@@ -699,7 +698,6 @@ func (s *matchingEngineSuite) concurrentPublishConsumeActivities(
		defer wg.Done()
		for i := int64(0); i < taskCount; {
			maxDispatch := dispatchLimitFn(wNum, i)
-			fmt.Println("Max dispatch: ", maxDispatch)
			result, err := s.matchingEngine.PollForActivityTask(s.callContext, &matching.PollForActivityTaskRequest{
				DomainUUID: common.StringPtr(domainID),
				PollRequest: &workflow.PollForActivityTaskRequest{
@@ -1324,6 +1322,7 @@ func (s *matchingEngineSuite) TestTaskListManagerGetTaskBatch() {
		})}
	}, nil)

+	time.Sleep(time.Second)
	// complete rangeSize events
	for i := int64(0); i < rangeSize; i++ {
		result, err := s.matchingEngine.PollForActivityTask(s.callContext, &matching.PollForActivityTaskRequest{
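
The deleted fmt.Println calls above were debug output in the zero-dispatch test, which expects at least 4 throttle errors when dispatchLimitFn returns 0. At the limiter level, that behavior comes from golang.org/x/time/rate itself: a limiter built with limit 0 and burst 0 fails Wait immediately rather than blocking forever. A standalone illustration of that limiter-level behavior (not code from this repo; how the throttle error surfaces to the poller is up to the matching engine code below):

package main

import (
	"context"
	"fmt"
	"time"

	"golang.org/x/time/rate"
)

func main() {
	// Mirrors rate.NewLimiter(rate.Limit(*maxDispatchPerSecond), int(*maxDispatchPerSecond))
	// when *maxDispatchPerSecond == 0: zero rate and zero burst.
	limiter := rate.NewLimiter(rate.Limit(0), 0)

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// Requesting one token exceeds a burst of 0, so Wait errors out right away.
	if err := limiter.Wait(ctx); err != nil {
		fmt.Println("dispatch throttled:", err)
	}
}
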
service/matching/taskListManager.go (62 changes: 35 additions & 27 deletions)
@@ -64,31 +64,35 @@ type taskListManager interface {
type rateLimiter struct {
	sync.RWMutex
	maxDispatchPerSecond *float64
-	limiter              *rate.Limiter
+	globalLimiter        atomic.Value
	// TTL is used to determine whether to update the limit. Until TTL, pick
	// lower(existing TTL, input TTL). After TTL, pick input TTL if different from existing TTL
	ttlTimer *time.Timer
	ttl      time.Duration
}

func newRateLimiter(maxDispatchPerSecond *float64, ttl time.Duration) rateLimiter {
-	return rateLimiter{
+	rl := rateLimiter{
		maxDispatchPerSecond: maxDispatchPerSecond,
-		// Note: Potentially expose burst config in future
-		limiter: rate.NewLimiter(
-			rate.Limit(*maxDispatchPerSecond), int(*maxDispatchPerSecond),
-		),
-		ttl:      ttl,
-		ttlTimer: time.NewTimer(ttl),
+		ttl:                  ttl,
+		ttlTimer:             time.NewTimer(ttl),
	}
+	// Note: Potentially expose burst config in future
+	limiter := rate.NewLimiter(
+		rate.Limit(*maxDispatchPerSecond), int(*maxDispatchPerSecond),
+	)
+	rl.globalLimiter.Store(limiter)
+	return rl
}

func (rl *rateLimiter) UpdateMaxDispatch(maxDispatchPerSecond *float64) {
	if rl.shouldUpdate(maxDispatchPerSecond) {
		rl.Lock()
-		defer rl.Unlock()
		rl.maxDispatchPerSecond = maxDispatchPerSecond
-		rl.limiter = rate.NewLimiter(rate.Limit(*maxDispatchPerSecond), int(*maxDispatchPerSecond))
+		rl.Unlock()
+		rl.globalLimiter.Store(
+			rate.NewLimiter(rate.Limit(*maxDispatchPerSecond), int(*maxDispatchPerSecond)),
+		)
	}
}

@@ -110,15 +114,13 @@ func (rl *rateLimiter) shouldUpdate(maxDispatchPerSecond *float64) bool {
}

func (rl *rateLimiter) Wait(ctx context.Context) error {
-	rl.RLock()
-	defer rl.RUnlock()
-	return rl.limiter.Wait(ctx)
+	limiter := rl.globalLimiter.Load().(*rate.Limiter)
+	return limiter.Wait(ctx)
}

func (rl *rateLimiter) Reserve() *rate.Reservation {
-	rl.RLock()
-	defer rl.RUnlock()
-	return rl.limiter.Reserve()
+	limiter := rl.globalLimiter.Load().(*rate.Limiter)
+	return limiter.Reserve()
}

func newTaskListManager(
@@ -139,11 +141,12 @@ func newTaskListManagerWithRateLimiter(
	// To perform one db operation if there are no pollers
	taskBufferSize := config.GetTasksBatchSize - 1
	tlMgr := &taskListManagerImpl{
-		engine:     e,
-		taskBuffer: make(chan *persistence.TaskInfo, taskBufferSize),
-		notifyCh:   make(chan struct{}, 1),
-		shutdownCh: make(chan struct{}),
-		taskListID: taskList,
+		engine:                  e,
+		taskBuffer:              make(chan *persistence.TaskInfo, taskBufferSize),
+		notifyCh:                make(chan struct{}, 1),
+		shutdownCh:              make(chan struct{}),
+		deliverBufferShutdownCh: make(chan struct{}),
+		taskListID:              taskList,
		logger: e.logger.WithFields(bark.Fields{
			logging.TagTaskListType: taskList.taskType,
			logging.TagTaskListName: taskList.taskListName,
@@ -193,10 +196,13 @@ type taskListManagerImpl struct {
	// only if there is waiting poll that consumes from it. Tasks in taskBuffer will blocking-add to
	// this channel
	tasksForPoll chan *getTaskResult
-	notifyCh     chan struct{}  // Used as signal to notify pump of new tasks
-	shutdownCh   chan struct{}  // Delivers stop to the pump that populates taskBuffer
-	startWG      sync.WaitGroup // ensures that background processes do not start until setup is ready
-	stopped      int32
+	notifyCh     chan struct{} // Used as signal to notify pump of new tasks
+	// Note: We need two shutdown channels so we can stop task pump independently of the deliverBuffer
+	// loop in getTasksPump in unit tests
+	shutdownCh              chan struct{}  // Delivers stop to the pump that populates taskBuffer
+	deliverBufferShutdownCh chan struct{}  // Delivers stop to the pump that populates taskBuffer
+	startWG                 sync.WaitGroup // ensures that background processes do not start until setup is ready
+	stopped                 int32

	sync.Mutex
	taskAckManager ackManager // tracks ackLevel for delivered messages
@@ -255,6 +261,7 @@ func (c *taskListManagerImpl) Stop() {
		return
	}
	close(c.shutdownCh)
+	close(c.deliverBufferShutdownCh)
	c.taskWriter.Stop()
	c.engine.removeTaskListManager(c.taskListID)
	logging.LogTaskListUnloadedEvent(c.logger)
@@ -596,12 +603,13 @@ func (c *taskListManagerImpl) getTasksPump() {
			defer wg.Done()
			select {
			case <-done:
-			case <-c.shutdownCh:
+			case <-c.deliverBufferShutdownCh:
			}
			cancel()
		}()
		err := c.rateLimiter.Wait(ctx)
		if err != nil {
+			c.logger.Warn("Unable to send tasks for poll, limit exceeded")
			done <- struct{}{}
			wg.Wait()
			c.metricsClient.IncCounter(scope, metrics.AddThrottleCounter)
@@ -615,7 +623,7 @@ func (c *taskListManagerImpl) getTasksPump() {
				break deliverBufferTasksLoop
			}
			c.tasksForPoll <- &getTaskResult{task: task}
-		case <-c.shutdownCh:
+		case <-c.deliverBufferShutdownCh:
			break deliverBufferTasksLoop
		}
	}
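The second half of the change adds deliverBufferShutdownCh alongside shutdownCh so that, as the new comment in taskListManagerImpl says, unit tests can stop the buffer-delivery loop in getTasksPump independently of the main task pump. A rough, self-contained sketch of that two-channel shutdown split follows; manager, pump, and deliverLoop are illustrative names, not the ones used in the commit.

package main

import (
	"fmt"
	"time"
)

// manager owns two independent shutdown signals, mirroring the split above:
// closing deliverShutdownCh stops only the delivery loop, while shutdownCh
// stops the producing pump.
type manager struct {
	shutdownCh        chan struct{}
	deliverShutdownCh chan struct{}
	taskBuffer        chan int
}

// pump keeps producing tasks until shutdownCh is closed.
func (m *manager) pump() {
	for i := 0; ; i++ {
		select {
		case m.taskBuffer <- i:
			time.Sleep(10 * time.Millisecond)
		case <-m.shutdownCh:
			return
		}
	}
}

// deliverLoop consumes buffered tasks until deliverShutdownCh is closed.
func (m *manager) deliverLoop() {
	for {
		select {
		case t := <-m.taskBuffer:
			fmt.Println("delivered task", t)
		case <-m.deliverShutdownCh:
			fmt.Println("delivery stopped; pump still running")
			return
		}
	}
}

func main() {
	m := &manager{
		shutdownCh:        make(chan struct{}),
		deliverShutdownCh: make(chan struct{}),
		taskBuffer:        make(chan int, 4),
	}
	go m.pump()
	go m.deliverLoop()

	time.Sleep(30 * time.Millisecond)
	close(m.deliverShutdownCh) // a test can halt just the delivery loop...
	time.Sleep(30 * time.Millisecond)
	close(m.shutdownCh) // ...while Stop() in the real code closes both channels
}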
