diff --git a/cluster/loadbalance/round_robin.go b/cluster/loadbalance/round_robin.go new file mode 100644 index 0000000000..e173e211c3 --- /dev/null +++ b/cluster/loadbalance/round_robin.go @@ -0,0 +1,153 @@ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loadbalance + +import ( + "math" + "sync" + "sync/atomic" + "time" +) + +import ( + "github.com/apache/dubbo-go/cluster" + "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/protocol" +) + +const ( + RoundRobin = "roundrobin" + + COMPLETE = 0 + UPDATING = 1 +) + +var ( + methodWeightMap sync.Map // [string]invokers + state int32 = COMPLETE // update lock acquired ? + recyclePeriod int64 = 60 * time.Second.Nanoseconds() +) + +func init() { + extension.SetLoadbalance(RoundRobin, NewRoundRobinLoadBalance) +} + +type roundRobinLoadBalance struct{} + +func NewRoundRobinLoadBalance() cluster.LoadBalance { + return &roundRobinLoadBalance{} +} + +func (lb *roundRobinLoadBalance) Select(invokers []protocol.Invoker, invocation protocol.Invocation) protocol.Invoker { + count := len(invokers) + if count == 0 { + return nil + } + if count == 1 { + return invokers[0] + } + + key := invokers[0].GetUrl().Path + "." + invocation.MethodName() + cache, _ := methodWeightMap.LoadOrStore(key, &cachedInvokers{}) + cachedInvokers := cache.(*cachedInvokers) + + var ( + clean = false + totalWeight = int64(0) + maxCurrentWeight = int64(math.MinInt64) + now = time.Now() + selectedInvoker protocol.Invoker + selectedWeightRobin *weightedRoundRobin + ) + + for _, invoker := range invokers { + var weight = GetWeight(invoker, invocation) + if weight < 0 { + weight = 0 + } + + identifier := invoker.GetUrl().Key() + loaded, found := cachedInvokers.LoadOrStore(identifier, &weightedRoundRobin{weight: weight}) + weightRobin := loaded.(*weightedRoundRobin) + if !found { + clean = true + } + + if weightRobin.Weight() != weight { + weightRobin.setWeight(weight) + } + + currentWeight := weightRobin.increaseCurrent() + weightRobin.lastUpdate = &now + + if currentWeight > maxCurrentWeight { + maxCurrentWeight = currentWeight + selectedInvoker = invoker + selectedWeightRobin = weightRobin + } + totalWeight += weight + } + + cleanIfRequired(clean, cachedInvokers, &now) + + if selectedWeightRobin != nil { + selectedWeightRobin.Current(totalWeight) + return selectedInvoker + } + + // should never happen + return invokers[0] +} + +func cleanIfRequired(clean bool, invokers *cachedInvokers, now *time.Time) { + if clean && atomic.CompareAndSwapInt32(&state, COMPLETE, UPDATING) { + defer atomic.CompareAndSwapInt32(&state, UPDATING, COMPLETE) + invokers.Range(func(identify, robin interface{}) bool { + weightedRoundRobin := robin.(*weightedRoundRobin) + elapsed := now.Sub(*weightedRoundRobin.lastUpdate).Nanoseconds() + if elapsed > recyclePeriod { + invokers.Delete(identify) + } + return true + }) + } +} + +// Record the weight of the invoker +type weightedRoundRobin struct { + weight int64 + current int64 + lastUpdate *time.Time +} + +func (robin *weightedRoundRobin) Weight() int64 { + return atomic.LoadInt64(&robin.weight) +} + +func (robin *weightedRoundRobin) setWeight(weight int64) { + robin.weight = weight + robin.current = 0 +} + +func (robin *weightedRoundRobin) increaseCurrent() int64 { + return atomic.AddInt64(&robin.current, robin.weight) +} + +func (robin *weightedRoundRobin) Current(delta int64) { + atomic.AddInt64(&robin.current, -1*delta) +} + +type cachedInvokers struct { + sync.Map /*[string]weightedRoundRobin*/ +} diff --git a/cluster/loadbalance/round_robin_test.go b/cluster/loadbalance/round_robin_test.go new file mode 100644 index 0000000000..e261884b55 --- /dev/null +++ b/cluster/loadbalance/round_robin_test.go @@ -0,0 +1,59 @@ +package loadbalance + +import ( + "context" + "fmt" + "strconv" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +import ( + "github.com/apache/dubbo-go/common" + "github.com/apache/dubbo-go/protocol" + "github.com/apache/dubbo-go/protocol/invocation" +) + +func TestRoundRobinSelect(t *testing.T) { + loadBalance := NewRoundRobinLoadBalance() + + var invokers []protocol.Invoker + + url, _ := common.NewURL(context.TODO(), "dubbo://192.168.1.0:20000/org.apache.demo.HelloService") + invokers = append(invokers, protocol.NewBaseInvoker(url)) + i := loadBalance.Select(invokers, &invocation.RPCInvocation{}) + assert.True(t, i.GetUrl().URLEqual(url)) + + for i := 1; i < 10; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService", i)) + invokers = append(invokers, protocol.NewBaseInvoker(url)) + } + loadBalance.Select(invokers, &invocation.RPCInvocation{}) +} + +func TestRoundRobinByWeight(t *testing.T) { + loadBalance := NewRoundRobinLoadBalance() + + var invokers []protocol.Invoker + loop := 10 + for i := 1; i <= loop; i++ { + url, _ := common.NewURL(context.TODO(), fmt.Sprintf("dubbo://192.168.1.%v:20000/org.apache.demo.HelloService?weight=%v", i, i)) + invokers = append(invokers, protocol.NewBaseInvoker(url)) + } + + loop = (1 + loop) * loop / 2 + selected := make(map[protocol.Invoker]int) + + for i := 1; i <= loop; i++ { + invoker := loadBalance.Select(invokers, &invocation.RPCInvocation{}) + selected[invoker]++ + } + + for _, i := range invokers { + w, _ := strconv.Atoi(i.GetUrl().GetParam("weight", "-1")) + assert.True(t, selected[i] == w) + } +} diff --git a/cluster/loadbalance/util.go b/cluster/loadbalance/util.go index 736952159d..7e0c2e2650 100644 --- a/cluster/loadbalance/util.go +++ b/cluster/loadbalance/util.go @@ -28,7 +28,8 @@ import ( func GetWeight(invoker protocol.Invoker, invocation protocol.Invocation) int64 { url := invoker.GetUrl() - weight := url.GetMethodParamInt(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT) + weight := url.GetMethodParamInt64(invocation.MethodName(), constant.WEIGHT_KEY, constant.DEFAULT_WEIGHT) + if weight > 0 { //get service register time an do warm up time now := time.Now().Unix() diff --git a/common/url.go b/common/url.go index 115167ee3e..4fb1af767f 100644 --- a/common/url.go +++ b/common/url.go @@ -20,6 +20,7 @@ package common import ( "context" "fmt" + "math" "net" "net/url" "strconv" @@ -288,6 +289,15 @@ func (c URL) GetMethodParamInt(method string, key string, d int64) int64 { return int64(r) } +func (c URL) GetMethodParamInt64(method string, key string, d int64) int64 { + r := c.GetMethodParamInt(method, key, math.MinInt64) + if r == math.MinInt64 { + return c.GetParamInt(key, d) + } + + return r +} + func (c URL) GetMethodParam(method string, key string, d string) string { var r string if r = c.Params.Get("methods." + method + "." + key); r == "" {