Skip to content

Commit

Permalink
Update weighted load balancer
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Nov 25, 2024
1 parent 5b2be37 commit 97068e7
Show file tree
Hide file tree
Showing 21 changed files with 879 additions and 392 deletions.
665 changes: 494 additions & 171 deletions .gen/proto/matching/v1/service.pb.go

Large diffs are not rendered by default.

293 changes: 149 additions & 144 deletions .gen/proto/matching/v1/service.pb.yarpc.go

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions client/matching/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ func (c *clientImpl) PollForActivityTask(
if err != nil {
return nil, err
}
// TODO: update activity response to include backlog count hint and update the weight for partitions
resp, err := c.client.PollForActivityTask(ctx, request, append(opts, yarpc.WithShardKey(peer))...)
if err != nil {
return nil, err
Expand All @@ -145,6 +144,14 @@ func (c *clientImpl) PollForActivityTask(
persistence.TaskListTypeActivity,
resp.PartitionConfig,
)
c.loadBalancer.UpdateWeight(
request.GetDomainUUID(),
*request.PollRequest.GetTaskList(),
persistence.TaskListTypeActivity,
request.GetForwardedFrom(),
partition,
resp.LoadBalancerHints,
)
return resp, nil
}

Expand Down Expand Up @@ -182,7 +189,7 @@ func (c *clientImpl) PollForDecisionTask(
persistence.TaskListTypeDecision,
request.GetForwardedFrom(),
partition,
resp.BacklogCountHint,
resp.LoadBalancerHints,
)
return resp, nil
}
Expand Down
3 changes: 2 additions & 1 deletion client/matching/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ func TestClient_withResponse(t *testing.T) {
p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil)
c.EXPECT().PollForActivityTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.MatchingPollForActivityTaskResponse{}, nil)
mp.EXPECT().UpdatePartitionConfig(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, nil)
balancer.EXPECT().UpdateWeight(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeActivity, "", _testPartition, nil)
},
want: &types.MatchingPollForActivityTaskResponse{},
},
Expand Down Expand Up @@ -274,7 +275,7 @@ func TestClient_withResponse(t *testing.T) {
p.EXPECT().FromTaskList(_testPartition).Return("peer0", nil)
c.EXPECT().PollForDecisionTask(gomock.Any(), gomock.Any(), []yarpc.CallOption{yarpc.WithShardKey("peer0")}).Return(&types.MatchingPollForDecisionTaskResponse{}, nil)
mp.EXPECT().UpdatePartitionConfig(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, nil)
balancer.EXPECT().UpdateWeight(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "", _testPartition, int64(0))
balancer.EXPECT().UpdateWeight(_testDomainUUID, types.TaskList{Name: _testTaskList}, persistence.TaskListTypeDecision, "", _testPartition, nil)
},
want: &types.MatchingPollForDecisionTaskResponse{},
},
Expand Down
4 changes: 2 additions & 2 deletions client/matching/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ type (
taskListType int,
forwardedFrom string,
partition string,
weight int64,
info *types.LoadBalancerHints,
)
}

Expand Down Expand Up @@ -138,7 +138,7 @@ func (lb *defaultLoadBalancer) UpdateWeight(
taskListType int,
forwardedFrom string,
partition string,
weight int64,
info *types.LoadBalancerHints,
) {
}

Expand Down
8 changes: 4 additions & 4 deletions client/matching/loadbalancer_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion client/matching/loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func Test_defaultLoadBalancer_UpdateWeight(t *testing.T) {
taskList := types.TaskList{Name: "test-task-list", Kind: types.TaskListKindNormal.Ptr()}

// Call UpdateWeight, should do nothing
loadBalancer.UpdateWeight("test-domain-id", taskList, 0, "", "partition", 10)
loadBalancer.UpdateWeight("test-domain-id", taskList, 0, "", "partition", nil)

// No expectations, just ensure no-op
})
Expand Down
4 changes: 2 additions & 2 deletions client/matching/multi_loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (lb *multiLoadBalancer) UpdateWeight(
taskListType int,
forwardedFrom string,
partition string,
weight int64,
info *types.LoadBalancerHints,
) {
domainName, err := lb.domainIDToName(domainID)
if err != nil {
Expand All @@ -111,5 +111,5 @@ func (lb *multiLoadBalancer) UpdateWeight(
lb.logger.Warn("unsupported load balancer strategy", tag.Value(strategy))
return
}
loadBalancer.UpdateWeight(domainID, taskList, taskListType, forwardedFrom, partition, weight)
loadBalancer.UpdateWeight(domainID, taskList, taskListType, forwardedFrom, partition, info)
}
57 changes: 33 additions & 24 deletions client/matching/multi_loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,40 +240,49 @@ func TestMultiLoadBalancer_UpdateWeight(t *testing.T) {
taskListType int
forwardedFrom string
partition string
weight int64
loadBalancerHints *types.LoadBalancerHints
loadbalancerStrategy string
shouldUpdate bool
}{
{
name: "do nothing when domainIDToName fails",
domainID: "invalid-domain",
taskList: types.TaskList{Name: "test-tasklist"},
taskListType: 1,
forwardedFrom: "",
partition: "partition-1",
weight: 10,
name: "do nothing when domainIDToName fails",
domainID: "invalid-domain",
taskList: types.TaskList{Name: "test-tasklist"},
taskListType: 1,
forwardedFrom: "",
partition: "partition-1",
loadBalancerHints: &types.LoadBalancerHints{
BacklogCount: 10,
RatePerSecond: 1,
},
loadbalancerStrategy: "random",
shouldUpdate: false,
},
{
name: "update weight with round-robin load balancer",
domainID: "valid-domain",
taskList: types.TaskList{Name: "test-tasklist"},
taskListType: 1,
forwardedFrom: "",
partition: "partition-2",
weight: 20,
name: "update weight with round-robin load balancer",
domainID: "valid-domain",
taskList: types.TaskList{Name: "test-tasklist"},
taskListType: 1,
forwardedFrom: "",
partition: "partition-2",
loadBalancerHints: &types.LoadBalancerHints{
BacklogCount: 20,
RatePerSecond: 2,
},
loadbalancerStrategy: "round-robin",
shouldUpdate: true,
},
{
name: "do nothing when strategy is unsupported",
domainID: "valid-domain",
taskList: types.TaskList{Name: "test-tasklist"},
taskListType: 1,
forwardedFrom: "",
partition: "partition-3",
weight: 30,
name: "do nothing when strategy is unsupported",
domainID: "valid-domain",
taskList: types.TaskList{Name: "test-tasklist"},
taskListType: 1,
forwardedFrom: "",
partition: "partition-3",
loadBalancerHints: &types.LoadBalancerHints{
BacklogCount: 30,
RatePerSecond: 3,
},
loadbalancerStrategy: "invalid-strategy",
shouldUpdate: false,
},
Expand All @@ -289,7 +298,7 @@ func TestMultiLoadBalancer_UpdateWeight(t *testing.T) {
roundRobinMock := NewMockLoadBalancer(ctrl)

if tt.shouldUpdate {
roundRobinMock.EXPECT().UpdateWeight(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom, tt.partition, tt.weight).Times(1)
roundRobinMock.EXPECT().UpdateWeight(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom, tt.partition, tt.loadBalancerHints).Times(1)
} else {
roundRobinMock.EXPECT().UpdateWeight(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0)
}
Expand All @@ -310,7 +319,7 @@ func TestMultiLoadBalancer_UpdateWeight(t *testing.T) {
}

// Call UpdateWeight
lb.UpdateWeight(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom, tt.partition, tt.weight)
lb.UpdateWeight(tt.domainID, tt.taskList, tt.taskListType, tt.forwardedFrom, tt.partition, tt.loadBalancerHints)
})
}
}
2 changes: 1 addition & 1 deletion client/matching/rr_loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,6 @@ func (lb *roundRobinLoadBalancer) UpdateWeight(
taskListType int,
forwardedFrom string,
partition string,
weight int64,
info *types.LoadBalancerHints,
) {
}
28 changes: 26 additions & 2 deletions client/matching/weighted_loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package matching

import (
"math"
"math/rand"
"path"
"sort"
Expand Down Expand Up @@ -112,6 +113,12 @@ func (pw *weightSelector) update(n, p int, weight int64) {
pw.initialized = true
}

func (pw *weightSelector) getWeights() []int64 {
pw.RLock()
defer pw.RUnlock()
return pw.weights
}

func NewWeightedLoadBalancer(
lb LoadBalancer,
provider PartitionConfigProvider,
Expand Down Expand Up @@ -179,14 +186,17 @@ func (lb *weightedLoadBalancer) UpdateWeight(
taskListType int,
forwardedFrom string,
partition string,
weight int64,
info *types.LoadBalancerHints,
) {
if forwardedFrom != "" || taskList.GetKind() == types.TaskListKindSticky {
return
}
if strings.HasPrefix(taskList.GetName(), common.ReservedTaskListPrefix) {
return
}
if info == nil {
return
}
p := 0
if partition != taskList.GetName() {
var err error
Expand Down Expand Up @@ -218,6 +228,20 @@ func (lb *weightedLoadBalancer) UpdateWeight(
if !ok {
return
}
lb.logger.Debug("update tasklist partition weight", tag.WorkflowDomainID(domainID), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskListType), tag.Dynamic("weights", w.weights), tag.Dynamic("tasklist-partition", p), tag.Dynamic("weight", weight))
weight := calcWeightFromLoadBalancerHints(info)
lb.logger.Debug("update tasklist partition weight", tag.WorkflowDomainID(domainID), tag.WorkflowTaskListName(taskList.GetName()), tag.WorkflowTaskListType(taskListType), tag.Dynamic("weights", w.getWeights()), tag.Dynamic("tasklist-partition", p), tag.Dynamic("weight", weight), tag.Dynamic("load-balancer-hints", info))
w.update(n, p, weight)
}

func calcWeightFromLoadBalancerHints(info *types.LoadBalancerHints) int64 {
// according to Little's Law, the average number of tasks in the queue L = λW
// where λ is the average arrival rate and W is the average wait time a task spends in the queue
// here λ is the QPS and W is the average match latency which is 10ms
// so the backlog hint should be backlog count + L.
smoothingNumber := int64(0)
qps := info.RatePerSecond
if qps > 0.01 {
smoothingNumber = int64(math.Ceil(qps * 0.01))
}
return info.BacklogCount + smoothingNumber
}
75 changes: 62 additions & 13 deletions client/matching/weighted_loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
package matching

import (
"math"
"math/rand"
"testing"
"time"
Expand Down Expand Up @@ -250,14 +251,14 @@ func TestWeightedLoadBalancer_PickReadPartition(t *testing.T) {

func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) {
testCases := []struct {
name string
domainID string
taskList types.TaskList
taskListType int
forwardedFrom string
partition string
weight int64
setupMock func(*cache.MockCache, *MockPartitionConfigProvider)
name string
domainID string
taskList types.TaskList
taskListType int
forwardedFrom string
partition string
loadBalancerHints *types.LoadBalancerHints
setupMock func(*cache.MockCache, *MockPartitionConfigProvider)
}{
{
name: "Sticky task list",
Expand All @@ -276,15 +277,18 @@ func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) {
taskList: types.TaskList{Name: "/__cadence_sys/aaa/1"},
},
{
name: "domain Name lookup error",
domainID: "invalid-domainID",
name: "nil loadBalancerHints",
domainID: "domainA",
taskList: types.TaskList{Name: "a"},
},
{
name: "1 partition",
domainID: "domainA",
taskList: types.TaskList{Name: "a"},
partition: "a",
loadBalancerHints: &types.LoadBalancerHints{
BacklogCount: 1,
},
setupMock: func(mockCache *cache.MockCache, mockPartitionConfigProvider *MockPartitionConfigProvider) {
mockPartitionConfigProvider.EXPECT().GetNumberOfReadPartitions("domainA", types.TaskList{Name: "a"}, 0).Return(1)
mockCache.EXPECT().Delete(key{
Expand All @@ -299,7 +303,9 @@ func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) {
domainID: "domainA",
taskList: types.TaskList{Name: "a"},
partition: "a",
weight: 1,
loadBalancerHints: &types.LoadBalancerHints{
BacklogCount: 1,
},
setupMock: func(mockCache *cache.MockCache, mockPartitionConfigProvider *MockPartitionConfigProvider) {
mockPartitionConfigProvider.EXPECT().GetNumberOfReadPartitions("domainA", types.TaskList{Name: "a"}, 0).Return(2)
mockCache.EXPECT().Get(key{
Expand All @@ -319,7 +325,9 @@ func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) {
domainID: "domainA",
taskList: types.TaskList{Name: "a"},
partition: "/__cadence_sys/a/1",
weight: 1,
loadBalancerHints: &types.LoadBalancerHints{
BacklogCount: 1,
},
setupMock: func(mockCache *cache.MockCache, mockPartitionConfigProvider *MockPartitionConfigProvider) {
mockPartitionConfigProvider.EXPECT().GetNumberOfReadPartitions("domainA", types.TaskList{Name: "a"}, 0).Return(2)
mockCache.EXPECT().Get(key{
Expand All @@ -345,7 +353,48 @@ func TestWeightedLoadBalancer_UpdateWeight(t *testing.T) {
tc.setupMock(mockWeightCache, mockPartitionConfigProvider)
}

lb.UpdateWeight(tc.domainID, tc.taskList, tc.taskListType, tc.forwardedFrom, tc.partition, tc.weight)
lb.UpdateWeight(tc.domainID, tc.taskList, tc.taskListType, tc.forwardedFrom, tc.partition, tc.loadBalancerHints)
})
}
}

func TestCalcWeightFromLoadBalancerHints(t *testing.T) {
tests := []struct {
name string
info types.LoadBalancerHints
expected int64
}{
{
name: "Zero QPS and backlog count",
info: types.LoadBalancerHints{BacklogCount: 0, RatePerSecond: 0},
expected: 0,
},
{
name: "Small QPS below threshold",
info: types.LoadBalancerHints{BacklogCount: 10, RatePerSecond: 0.005},
expected: 10,
},
{
name: "QPS above threshold with no backlog",
info: types.LoadBalancerHints{BacklogCount: 0, RatePerSecond: 2},
expected: int64(math.Ceil(2 * 0.01)), // smoothingNumber calculation
},
{
name: "QPS above threshold with backlog",
info: types.LoadBalancerHints{BacklogCount: 100, RatePerSecond: 5},
expected: 100 + int64(math.Ceil(5*0.01)), // backlog + smoothingNumber
},
{
name: "Large QPS",
info: types.LoadBalancerHints{BacklogCount: 50, RatePerSecond: 100},
expected: 50 + int64(math.Ceil(100*0.01)), // backlog + smoothingNumber
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := calcWeightFromLoadBalancerHints(&tt.info)
assert.Equal(t, tt.expected, result, "unexpected result for %s", tt.name)
})
}
}
Loading

0 comments on commit 97068e7

Please sign in to comment.