Skip to content

Commit

Permalink
Merge pull request #66 from zonghaishang/loadbalance_round_robin
Browse files Browse the repository at this point in the history
Loadbalance round robin
  • Loading branch information
zonghaishang committed Jun 13, 2019
2 parents f16fce4 + d081c07 commit 28d3c8d
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 1 deletion.
153 changes: 153 additions & 0 deletions cluster/loadbalance/round_robin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package loadbalance

import (
"math"
"sync"
"sync/atomic"
"time"
)

import (
"github.com/apache/dubbo-go/cluster"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/protocol"
)

const (
RoundRobin = "roundrobin"

COMPLETE = 0
UPDATING = 1
)

var (
methodWeightMap sync.Map // [string]invokers
state int32 = COMPLETE // update lock acquired ?
recyclePeriod int64 = 60 * time.Second.Nanoseconds()
)

func init() {
extension.SetLoadbalance(RoundRobin, NewRoundRobinLoadBalance)
}

type roundRobinLoadBalance struct{}

func NewRoundRobinLoadBalance() cluster.LoadBalance {
return &roundRobinLoadBalance{}
}

func (lb *roundRobinLoadBalance) Select(invokers []protocol.Invoker, invocation protocol.Invocation) protocol.Invoker {
count := len(invokers)
if count == 0 {
return nil
}
if count == 1 {
return invokers[0]
}

key := invokers[0].GetUrl().Path + "." + invocation.MethodName()
cache, _ := methodWeightMap.LoadOrStore(key, &cachedInvokers{})
cachedInvokers := cache.(*cachedInvokers)

var (
clean = false
totalWeight = int64(0)
maxCurrentWeight = int64(math.MinInt64)
now = time.Now()
selectedInvoker protocol.Invoker
selectedWeightRobin *weightedRoundRobin
)

for _, invoker := range invokers {
var weight = GetWeight(invoker, invocation)
if weight < 0 {
weight = 0
}

identifier := invoker.GetUrl().Key()
loaded, found := cachedInvokers.LoadOrStore(identifier, &weightedRoundRobin{weight: weight})
weightRobin := loaded.(*weightedRoundRobin)
if !found {
clean = true
}

if weightRobin.Weight() != weight {
weightRobin.setWeight(weight)
}

currentWeight := weightRobin.increaseCurrent()
weightRobin.lastUpdate = &now

if currentWeight > maxCurrentWeight {
maxCurrentWeight = currentWeight
selectedInvoker = invoker
selectedWeightRobin = weightRobin
}
totalWeight += weight
}

cleanIfRequired(clean, cachedInvokers, &now)

if selectedWeightRobin != nil {
selectedWeightRobin.Current(totalWeight)
return selectedInvoker
}

// should never happen
return invokers[0]
}

func cleanIfRequired(clean bool, invokers *cachedInvokers, now *time.Time) {
if clean && atomic.CompareAndSwapInt32(&state, COMPLETE, UPDATING) {
defer atomic.CompareAndSwapInt32(&state, UPDATING, COMPLETE)
invokers.Range(func(identify, robin interface{}) bool {
weightedRoundRobin := robin.(*weightedRoundRobin)
elapsed := now.Sub(*weightedRoundRobin.lastUpdate).Nanoseconds()
if elapsed > recyclePeriod {
invokers.Delete(identify)
}
return true
})
}
}

// Record the weight of the invoker
type weightedRoundRobin struct {
weight int64
current int64
lastUpdate *time.Time
}

func (robin *weightedRoundRobin) Weight() int64 {
return atomic.LoadInt64(&robin.weight)
}

func (robin *weightedRoundRobin) setWeight(weight int64) {
robin.weight = weight
robin.current = 0
}

func (robin *weightedRoundRobin) increaseCurrent() int64 {
return atomic.AddInt64(&robin.current, robin.weight)
}

func (robin *weightedRoundRobin) Current(delta int64) {
atomic.AddInt64(&robin.current, -1*delta)
}

type cachedInvokers struct {
sync.Map /*[string]weightedRoundRobin*/
}
59 changes: 59 additions & 0 deletions cluster/loadbalance/round_robin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package loadbalance

import (
"context"
"fmt"
"strconv"
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/invocation"
)

func TestRoundRobinSelect(t *testing.T) {
loadBalance := NewRoundRobinLoadBalance()

var invokers []protocol.Invoker

url, _ := common.NewURL(context.TODO(), "dubbo://192.168.1.0:20000/org.apache.demo.HelloService")
invokers = append(invokers, protocol.NewBaseInvoker(url))
i := loadBalance.Select(invokers, &invocation.RPCInvocation{})
assert.True(t, i.GetUrl().URLEqual(url))

for i := 1; i < 10; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService", i))
invokers = append(invokers, protocol.NewBaseInvoker(url))
}
loadBalance.Select(invokers, &invocation.RPCInvocation{})
}

func TestRoundRobinByWeight(t *testing.T) {
loadBalance := NewRoundRobinLoadBalance()

var invokers []protocol.Invoker
loop := 10
for i := 1; i <= loop; i++ {
url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService?weight=%v", i, i))
invokers = append(invokers, protocol.NewBaseInvoker(url))
}

loop = (1 + loop) * loop / 2
selected := make(map[protocol.Invoker]int)

for i := 1; i <= loop; i++ {
invoker := loadBalance.Select(invokers, &invocation.RPCInvocation{})
selected[invoker]++
}

for _, i := range invokers {
w, _ := strconv.Atoi(i.GetUrl().GetParam("weight", "-1"))
assert.True(t, selected[i] == w)
}
}
3 changes: 2 additions & 1 deletion cluster/loadbalance/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import (

func GetWeight(invoker protocol.Invoker, invocation protocol.Invocation) int64 {
url := invoker.GetUrl()
weight := url.GetMethodParamInt(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT)
weight := url.GetMethodParamInt64(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT)

if weight > 0 {
//get service register time an do warm up time
now := time.Now().Unix()
Expand Down
10 changes: 10 additions & 0 deletions common/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package common
import (
"context"
"fmt"
"math"
"net"
"net/url"
"strconv"
Expand Down Expand Up @@ -288,6 +289,15 @@ func (c URL) GetMethodParamInt(method string, key string, d int64) int64 {
return int64(r)
}

func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 {
r := c.GetMethodParamInt(method, key, math.MinInt64)
if r == math.MinInt64 {
return c.GetParamInt(key, d)
}

return r
}

func (c URL) GetMethodParam(method string, key string, d string) string {
var r string
if r = c.Params.Get("methods." + method + "." + key); r == "" {
Expand Down

0 comments on commit 28d3c8d

Please sign in to comment.