Skip to content
This repository has been archived by the owner on Feb 6, 2024. It is now read-only.

Commit

Permalink
refactor: refactor by cr
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuLiangWang committed Mar 30, 2023
1 parent f8b75dc commit 5914044
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 56 deletions.
115 changes: 79 additions & 36 deletions server/coordinator/watch/watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package coordinator
import (
"context"
"fmt"
"github.com/CeresDB/ceresdbproto/golang/pkg/metaeventpb"
"google.golang.org/protobuf/proto"
"strconv"
"strings"
"sync"
Expand All @@ -17,31 +19,42 @@ import (
"go.uber.org/zap"
)

type ShardEventType uint64

const (
shardPath = "shards"
EventDelete ShardEventType = 0
EventPut = 1
shardPath = "shards"
keySep = "/"
)

type ShardRegisterEvent struct {
ShardID storage.ShardID
NewLeaderNode string
}

type ShardExpireEvent struct {
ShardID storage.ShardID
OldLeaderNode string
}

type ShardEventCallback interface {
OnShardRegistered(event ShardRegisterEvent) error
OnShardExpired(event ShardExpireEvent) error
}

// ShardWatch used to watch the distributed lock of shard, and provide the corresponding callback function.
type ShardWatch struct {
rootPath string
etcdClient *clientv3.Client
eventCallbacks []func(shardID storage.ShardID, nodeName string, eventType ShardEventType) error
eventCallbacks []ShardEventCallback

lock sync.RWMutex
isRunning bool
quit chan bool
cancel context.CancelFunc
}

func NewWatch(rootPath string, client *clientv3.Client) *ShardWatch {
return &ShardWatch{
rootPath: rootPath,
etcdClient: client,
eventCallbacks: []func(shardID storage.ShardID, nodeName string, eventType ShardEventType) error{},
quit: make(chan bool),
eventCallbacks: []ShardEventCallback{},
}
}

Expand All @@ -53,9 +66,8 @@ func (w *ShardWatch) Start(ctx context.Context) error {
return nil
}

shardsKeyPath := strings.Join([]string{w.rootPath, shardPath}, "/")

if err := w.startWatch(ctx, shardsKeyPath); err != nil {
shardKeyPrefix := encodeShardKeyPrefix(w.rootPath, shardPath)
if err := w.startWatch(ctx, shardKeyPrefix); err != nil {
return errors.WithMessage(err, "etcd register watch failed")
}

Expand All @@ -68,29 +80,24 @@ func (w *ShardWatch) Stop(ctx context.Context) error {
defer w.lock.Unlock()

w.isRunning = false
w.quit <- true
w.cancel()
return nil
}

func (w *ShardWatch) RegisteringEventCallback(eventCallback func(shardID storage.ShardID, nodeName string, eventType ShardEventType) error) {
func (w *ShardWatch) RegisteringEventCallback(eventCallback ShardEventCallback) {
w.eventCallbacks = append(w.eventCallbacks, eventCallback)
}

func (w *ShardWatch) startWatch(ctx context.Context, path string) error {
log.Info("register shard watch", zap.String("watchPath", path))
go func() {
for {
select {
case <-w.quit:
return
default:
respChan := w.etcdClient.Watch(ctx, path, clientv3.WithPrefix(), clientv3.WithPrevKV())
for resp := range respChan {
for _, event := range resp.Events {
if err := w.processEvent(event); err != nil {
log.Error("process event", zap.Error(err))
}
}
ctxWithCancel, cancel := context.WithCancel(ctx)
w.cancel = cancel
respChan := w.etcdClient.Watch(ctxWithCancel, path, clientv3.WithPrefix(), clientv3.WithPrevKV())
for resp := range respChan {
for _, event := range resp.Events {
if err := w.processEvent(event); err != nil {
log.Error("process event", zap.Error(err))
}
}
}
Expand All @@ -101,31 +108,67 @@ func (w *ShardWatch) startWatch(ctx context.Context, path string) error {
func (w *ShardWatch) processEvent(event *clientv3.Event) error {
switch event.Type {
case mvccpb.DELETE:
pathList := strings.Split(string(event.Kv.Key), "/")
shardID, err := strconv.ParseUint(pathList[len(pathList)-1], 10, 64)
shardID, err := decodeShardKey(string(event.Kv.Key))
if err != nil {
return err
}
shardLockValue, err := convertShardLockValueToPB(event.PrevKv.Value)
if err != nil {
return err
}
oldLeader := string(event.PrevKv.Value)
log.Info("receive delete event", zap.String("preKV", fmt.Sprintf("%v", event.PrevKv)), zap.String("event", fmt.Sprintf("%v", event)), zap.Uint64("shardID", shardID), zap.String("oldLeader", oldLeader))
log.Info("receive delete event", zap.String("preKV", fmt.Sprintf("%v", event.PrevKv)), zap.String("event", fmt.Sprintf("%v", event)), zap.Uint64("shardID", shardID), zap.String("oldLeader", shardLockValue.NodeName))
for _, callback := range w.eventCallbacks {
if err := callback(storage.ShardID(shardID), oldLeader, EventDelete); err != nil {
if err := callback.OnShardExpired(ShardExpireEvent{
ShardID: storage.ShardID(shardID),
OldLeaderNode: shardLockValue.NodeName,
}); err != nil {
return err
}
}
case mvccpb.PUT:
pathList := strings.Split(string(event.Kv.Key), "/")
shardID, err := strconv.ParseUint(pathList[len(pathList)-1], 10, 64)
shardID, err := decodeShardKey(string(event.Kv.Key))
if err != nil {
return err
}
shardLockValue, err := convertShardLockValueToPB(event.Kv.Value)
if err != nil {
return err
}
newLeader := string(event.Kv.Value)
log.Info("receive put event", zap.String("event", fmt.Sprintf("%v", event)), zap.Uint64("shardID", shardID), zap.String("oldLeader", newLeader))
log.Info("receive put event", zap.String("event", fmt.Sprintf("%v", event)), zap.Uint64("shardID", shardID), zap.String("oldLeader", shardLockValue.NodeName))
for _, callback := range w.eventCallbacks {
if err := callback(storage.ShardID(shardID), newLeader, EventPut); err != nil {
if err := callback.OnShardRegistered(ShardRegisterEvent{
ShardID: storage.ShardID(shardID),
NewLeaderNode: shardLockValue.NodeName,
}); err != nil {
return err
}
}
}
return nil
}

func decodeShardKey(keyPath string) (uint64, error) {
pathList := strings.Split(keyPath, keySep)
shardID, err := strconv.ParseUint(pathList[len(pathList)-1], 10, 64)
if err != nil {
return 0, errors.WithMessage(err, "decode etcd event key failed")
}
return shardID, nil
}

func encodeShardKeyPrefix(rootPath, shardPath string) string {
return strings.Join([]string{rootPath, shardPath}, keySep)
}

func encodeShardKey(rootPath string, shardPath string, shardID uint64) string {
shardKeyPrefix := encodeShardKeyPrefix(rootPath, shardPath)
return strings.Join([]string{shardKeyPrefix, strconv.FormatUint(shardID, 10)}, keySep)
}

func convertShardLockValueToPB(value []byte) (metaeventpb.ShardLockValue, error) {
shardLockValue := metaeventpb.ShardLockValue{}
if err := proto.Unmarshal(value, &shardLockValue); err != nil {
return shardLockValue, errors.WithMessage(err, "unmarshal shardLockValue failed")
}
return shardLockValue, nil
}
53 changes: 33 additions & 20 deletions server/coordinator/watch/watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package coordinator

import (
"context"
"strconv"
"strings"
"github.com/CeresDB/ceresdbproto/golang/pkg/metaeventpb"
"google.golang.org/protobuf/proto"
"testing"
"time"

Expand All @@ -31,31 +31,44 @@ func TestWatch(t *testing.T) {
err := watch.Start(ctx)
re.NoError(err)

callbackResult := 0
testCallback := func(shardID storage.ShardID, nodeName string, eventType ShardEventType) error {
switch eventType {
case EventDelete:
callbackResult = 1
re.Equal(storage.ShardID(TestShardID), shardID)
re.Equal(TestNodeName, nodeName)
case EventPut:
callbackResult = 2
re.Equal(storage.ShardID(TestShardID), shardID)
re.Equal(TestNodeName, nodeName)
}
return nil
testCallback := testShardEventCallback{
result: 0,
re: re,
}
watch.RegisteringEventCallback(testCallback)

watch.RegisteringEventCallback(&testCallback)

// Valid that callback function is executed and the params are as expected.
keyPath := strings.Join([]string{TestRootPath, TestShardPath, strconv.Itoa(TestShardID)}, "/")
_, err = client.Put(ctx, keyPath, TestNodeName)
b, err := proto.Marshal(&metaeventpb.ShardLockValue{NodeName: TestNodeName})
re.NoError(err)

keyPath := encodeShardKey(TestRootPath, TestShardPath, TestShardID)
_, err = client.Put(ctx, keyPath, string(b))
re.NoError(err)
time.Sleep(time.Millisecond * 10)
re.Equal(callbackResult, 2)
re.Equal(2, testCallback.result)

_, err = client.Delete(ctx, keyPath, clientv3.WithPrevKV())
re.NoError(err)
time.Sleep(time.Millisecond * 10)
re.Equal(callbackResult, 1)
re.Equal(1, testCallback.result)
}

type testShardEventCallback struct {
result int
re *require.Assertions
}

func (c *testShardEventCallback) OnShardRegistered(event ShardRegisterEvent) error {
c.result = 2
c.re.Equal(storage.ShardID(TestShardID), event.ShardID)
c.re.Equal(TestNodeName, event.NewLeaderNode)
return nil
}

func (c *testShardEventCallback) OnShardExpired(event ShardExpireEvent) error {
c.result = 1
c.re.Equal(storage.ShardID(TestShardID), event.ShardID)
c.re.Equal(TestNodeName, event.OldLeaderNode)
return nil
}

0 comments on commit 5914044

Please sign in to comment.