diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 86431c8a..73c7a103 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,12 +1,12 @@ name: Test -on: - push: - branches: - - main - pull_request: - branches: - - main +#on: +# push: +# branches: +# - main +# pull_request: +# branches: +# - main jobs: milvus-cdc-test: diff --git a/codecov.yml b/codecov.yml index 199efe09..8841158a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -11,7 +11,7 @@ coverage: status: project: default: - threshold: 0% #Allow the coverage to drop by threshold%, and posting a success status. + threshold: 80% #Allow the coverage to drop by threshold%, and posting a success status. branches: - main patch: diff --git a/core/api/reader.go b/core/api/reader.go index ec84aba3..ae1ce8af 100644 --- a/core/api/reader.go +++ b/core/api/reader.go @@ -9,6 +9,7 @@ import ( type Reader interface { StartRead(ctx context.Context) QuitRead(ctx context.Context) + ErrorChan() <-chan error } // DefaultReader All CDCReader implements should combine it @@ -25,3 +26,8 @@ func (d *DefaultReader) StartRead(ctx context.Context) { func (d *DefaultReader) QuitRead(ctx context.Context) { log.Warn("QuitRead is not implemented, please check it") } + +func (d *DefaultReader) ErrorChan() <-chan error { + log.Warn("ErrorChan is not implemented, please check it") + return nil +} diff --git a/core/api/reader_test.go b/core/api/reader_test.go index 2350930e..9ca006e2 100644 --- a/core/api/reader_test.go +++ b/core/api/reader_test.go @@ -2,6 +2,7 @@ package api import ( "context" + "reflect" "testing" ) @@ -46,3 +47,23 @@ func TestDefaultReader_StartRead(t *testing.T) { }) } } + +func TestDefaultReader_ErrorChan(t *testing.T) { + tests := []struct { + name string + want <-chan error + }{ + { + name: "TestDefaultReader_ErrorChan", + want: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &DefaultReader{} + if got := d.ErrorChan(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("ErrorChan() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/core/api/replicate_manager.go b/core/api/replicate_manager.go index e5af502e..e77b7a47 100644 --- a/core/api/replicate_manager.go +++ b/core/api/replicate_manager.go @@ -33,6 +33,7 @@ type ReplicateAPIEvent struct { CollectionInfo *pb.CollectionInfo PartitionInfo *pb.PartitionInfo ReplicateInfo *commonpb.ReplicateInfo + Error error } type ReplicateAPIEventType int @@ -42,6 +43,8 @@ const ( ReplicateDropCollection ReplicateCreatePartition ReplicateDropPartition + + ReplicateError = 100 ) type DefaultChannelManager struct{} diff --git a/core/mocks/reader.go b/core/mocks/reader.go index 4437e8d5..ab7329f6 100644 --- a/core/mocks/reader.go +++ b/core/mocks/reader.go @@ -21,6 +21,49 @@ func (_m *Reader) EXPECT() *Reader_Expecter { return &Reader_Expecter{mock: &_m.Mock} } +// ErrorChan provides a mock function with given fields: +func (_m *Reader) ErrorChan() <-chan error { + ret := _m.Called() + + var r0 <-chan error + if rf, ok := ret.Get(0).(func() <-chan error); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan error) + } + } + + return r0 +} + +// Reader_ErrorChan_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ErrorChan' +type Reader_ErrorChan_Call struct { + *mock.Call +} + +// ErrorChan is a helper method to define mock.On call +func (_e *Reader_Expecter) ErrorChan() *Reader_ErrorChan_Call { + return &Reader_ErrorChan_Call{Call: _e.mock.On("ErrorChan")} +} + +func (_c *Reader_ErrorChan_Call) Run(run func()) *Reader_ErrorChan_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Reader_ErrorChan_Call) Return(_a0 <-chan error) *Reader_ErrorChan_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Reader_ErrorChan_Call) RunAndReturn(run func() <-chan error) *Reader_ErrorChan_Call { + _c.Call.Return(run) + return _c +} + // QuitRead provides a mock function with given fields: ctx func (_m *Reader) QuitRead(ctx context.Context) { _m.Called(ctx) diff --git a/core/reader/collection_reader.go b/core/reader/collection_reader.go index 4d200c37..135576fb 100644 --- a/core/reader/collection_reader.go +++ b/core/reader/collection_reader.go @@ -19,14 +19,12 @@ package reader import ( "context" "sync" - "time" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" - "github.com/milvus-io/milvus/pkg/util/retry" "github.com/samber/lo" "go.uber.org/zap" @@ -57,7 +55,7 @@ type CollectionReader struct { channelSeekPositions map[string]*msgpb.MsgPosition replicateCollectionMap util.Map[int64, *pb.CollectionInfo] replicateChannelMap util.Map[string, struct{}] - replicateChannelChan chan string + errChan chan error shouldReadFunc ShouldReadFunc startOnce sync.Once quitOnce sync.Once @@ -70,7 +68,7 @@ func NewCollectionReader(id string, channelManager api.ChannelManager, metaOp ap metaOp: metaOp, channelSeekPositions: seekPosition, shouldReadFunc: shouldReadFunc, - replicateChannelChan: make(chan string, 10), + errChan: make(chan error), } return reader, nil } @@ -78,8 +76,10 @@ func NewCollectionReader(id string, channelManager api.ChannelManager, metaOp ap func (reader *CollectionReader) StartRead(ctx context.Context) { reader.startOnce.Do(func() { reader.metaOp.SubscribeCollectionEvent(reader.id, func(info *pb.CollectionInfo) bool { - log.Info("has watched to read collection", zap.String("name", info.Schema.Name)) + collectionLog := log.With(zap.String("collection_name", info.Schema.Name), zap.Int64("collection_id", info.ID)) + collectionLog.Info("has watched to read collection") if !reader.shouldReadFunc(info) { + collectionLog.Info("the collection should not be read") return false } startPositions := make([]*msgpb.MsgPosition, 0) @@ -90,16 +90,19 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { }) } if err := reader.channelManager.StartReadCollection(ctx, info, startPositions); err != nil { - log.Warn("fail to start to replicate the collection data in the watch process", zap.Int64("id", info.ID), zap.Error(err)) + collectionLog.Warn("fail to start to replicate the collection data in the watch process", zap.Any("info", info), zap.Error(err)) + reader.sendError(err) } reader.replicateCollectionMap.Store(info.ID, info) - log.Info("has started to read collection", zap.String("name", info.Schema.Name)) + collectionLog.Info("has started to read collection") return true }) reader.metaOp.SubscribePartitionEvent(reader.id, func(info *pb.PartitionInfo) bool { + partitionLog := log.With(zap.Int64("collection_id", info.CollectionID), zap.Int64("partition_id", info.PartitionID), zap.String("partition_name", info.PartitionName)) + partitionLog.Info("has watched to read partition") collectionName := reader.metaOp.GetCollectionNameByID(ctx, info.CollectionID) if collectionName == "" { - log.Info("the collection name is empty", zap.Int64("collection_id", info.CollectionID), zap.String("partition_name", info.PartitionName)) + partitionLog.Info("the collection name is empty") return true } tmpCollectionInfo := &pb.CollectionInfo{ @@ -109,17 +112,16 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { }, } if !reader.shouldReadFunc(tmpCollectionInfo) { + partitionLog.Info("the partition should not be read", zap.String("name", collectionName)) return true } - var err error - err = retry.Do(ctx, func() error { - err = reader.channelManager.AddPartition(ctx, tmpCollectionInfo, info) - return err - }, retry.Sleep(time.Second)) + err := reader.channelManager.AddPartition(ctx, tmpCollectionInfo, info) if err != nil { - log.Panic("fail to add partition", zap.String("collection_name", collectionName), zap.String("partition_name", info.PartitionName), zap.Error(err)) + partitionLog.Warn("fail to add partition", zap.String("collection_name", collectionName), zap.Any("partition", info), zap.Error(err)) + reader.sendError(err) } + partitionLog.Info("has started to add partition") return false }) reader.metaOp.WatchCollection(ctx, nil) @@ -130,12 +132,15 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { }) if err != nil { log.Warn("get all collection failed", zap.Error(err)) + reader.sendError(err) + return } seekPositions := lo.Values(reader.channelSeekPositions) for _, info := range existedCollectionInfos { log.Info("exist collection", zap.String("name", info.Schema.Name)) if err := reader.channelManager.StartReadCollection(ctx, info, seekPositions); err != nil { - log.Warn("fail to start to replicate the collection data", zap.Int64("id", info.ID), zap.Error(err)) + log.Warn("fail to start to replicate the collection data", zap.Any("collection", info), zap.Error(err)) + reader.sendError(err) } reader.replicateCollectionMap.Store(info.ID, info) } @@ -155,22 +160,29 @@ func (reader *CollectionReader) StartRead(ctx context.Context) { log.Info("the collection is not in the watch list", zap.String("collection_name", collectionName), zap.String("partition_name", info.PartitionName)) return true } - var err error - err = retry.Do(ctx, func() error { - err = reader.channelManager.AddPartition(ctx, tmpCollectionInfo, info) - return err - }, retry.Sleep(time.Second)) + err := reader.channelManager.AddPartition(ctx, tmpCollectionInfo, info) if err != nil { - log.Panic("fail to add partition", zap.String("collection_name", collectionName), zap.String("partition_name", info.PartitionName), zap.Error(err)) + log.Warn("fail to add partition", zap.String("collection_name", collectionName), zap.String("partition_name", info.PartitionName), zap.Error(err)) + reader.sendError(err) } return false }) if err != nil { log.Warn("get all partition failed", zap.Error(err)) + reader.sendError(err) } }) } +func (reader *CollectionReader) sendError(err error) { + select { + case reader.errChan <- err: + log.Info("send the error", zap.String("id", reader.id), zap.Error(err)) + default: + log.Info("skip the error, because it will quit soon", zap.String("id", reader.id), zap.Error(err)) + } +} + func (reader *CollectionReader) QuitRead(ctx context.Context) { reader.quitOnce.Do(func() { reader.replicateCollectionMap.Range(func(_ int64, value *pb.CollectionInfo) bool { @@ -182,5 +194,10 @@ func (reader *CollectionReader) QuitRead(ctx context.Context) { }) reader.metaOp.UnsubscribeEvent(reader.id, api.CollectionEventType) reader.metaOp.UnsubscribeEvent(reader.id, api.PartitionEventType) + reader.sendError(nil) }) } + +func (reader *CollectionReader) ErrorChan() <-chan error { + return reader.errChan +} diff --git a/core/reader/collection_reader_test.go b/core/reader/collection_reader_test.go index d3e48891..d6b7e0d2 100644 --- a/core/reader/collection_reader_test.go +++ b/core/reader/collection_reader_test.go @@ -19,6 +19,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/pb" ) +// Before running this case, should start the etcd server func TestCollectionReader(t *testing.T) { etcdOp, err := NewEtcdOp(nil, "", "", "") assert.NoError(t, err) @@ -86,11 +87,20 @@ func TestCollectionReader(t *testing.T) { return !strings.Contains(ci.Schema.Name, "test") }) assert.NoError(t, err) + go func() { + select { + case <-time.After(time.Second): + t.Fail() + case err := <-reader.ErrorChan(): + assert.Error(t, err) + } + }() reader.StartRead(context.Background()) channelManager.EXPECT().StartReadCollection(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() channelManager.EXPECT().AddPartition(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - // add collection + { + // filter collection field3 := &schemapb.FieldSchema{ FieldID: 100, Name: "age", @@ -106,6 +116,7 @@ func TestCollectionReader(t *testing.T) { collectionBytes, _ := proto.Marshal(collectionInfo) _, _ = realOp.etcdClient.Put(context.Background(), realOp.collectionPrefix()+"/1/100003", string(collectionBytes)) + // filter partition { info := &pb.PartitionInfo{ State: pb.PartitionState_PartitionCreated, @@ -116,8 +127,9 @@ func TestCollectionReader(t *testing.T) { _, _ = realOp.etcdClient.Put(context.Background(), realOp.partitionPrefix()+"/100003/300047", getStringForMessage(info)) } } - // add partition + { + // put collection field3 := &schemapb.FieldSchema{ FieldID: 100, Name: "age", @@ -138,6 +150,8 @@ func TestCollectionReader(t *testing.T) { } collectionBytes, _ := proto.Marshal(collectionInfo) _, _ = realOp.etcdClient.Put(context.Background(), realOp.collectionPrefix()+"/1/100004", string(collectionBytes)) + + // put partition info := &pb.PartitionInfo{ State: pb.PartitionState_PartitionCreated, PartitionName: "foo", diff --git a/core/reader/replicate_channel_manager.go b/core/reader/replicate_channel_manager.go index 8c3073c1..77779c15 100644 --- a/core/reader/replicate_channel_manager.go +++ b/core/reader/replicate_channel_manager.go @@ -3,6 +3,7 @@ package reader import ( "context" "math/rand" + "sort" "strconv" "strings" "sync" @@ -102,13 +103,6 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info } log.Info("success to get the collection info in the target instance", zap.String("collection_name", targetInfo.CollectionName)) - for i, channel := range targetInfo.PChannels { - if !strings.Contains(targetInfo.VChannels[i], channel) { - log.Warn("physical channel not equal", zap.Strings("p", targetInfo.PChannels), zap.Strings("v", targetInfo.VChannels)) - return errors.New("the physical channels are not matched to the virtual channels") - } - } - getSeekPosition := func(channelName string) *msgpb.MsgPosition { for _, seekPosition := range seekPositions { if seekPosition.ChannelName == channelName { @@ -140,34 +134,62 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, info r.collectionLock.Unlock() var successChannels []string - for i, position := range info.StartPositions { - channelName := position.GetKey() + err = ForeachChannel(info.PhysicalChannelNames, targetInfo.PChannels, func(sourcePChannel, targetPChannel string) error { err := r.startReadChannel(&model.SourceCollectionInfo{ - PChannelName: channelName, + PChannelName: sourcePChannel, CollectionID: info.ID, - SeekPosition: getSeekPosition(channelName), + SeekPosition: getSeekPosition(sourcePChannel), }, &model.TargetCollectionInfo{ CollectionID: targetInfo.CollectionID, CollectionName: info.Schema.Name, PartitionInfo: targetInfo.Partitions, - PChannel: targetInfo.PChannels[i], - VChannel: targetInfo.VChannels[i], + PChannel: targetPChannel, + VChannel: GetVChannelByPChannel(targetPChannel, targetInfo.VChannels), BarrierChan: barrier.BarrierChan, PartitionBarrierChan: make(map[int64]chan<- uint64), }) if err != nil { - log.Warn("start read channel failed", zap.String("channel", channelName), zap.Int64("collection_id", info.ID), zap.Error(err)) + log.Warn("start read channel failed", zap.String("channel", sourcePChannel), zap.Int64("collection_id", info.ID), zap.Error(err)) for _, channel := range successChannels { r.stopReadChannel(channel, info.ID) } return err } - successChannels = append(successChannels, channelName) - log.Info("start read channel", zap.String("channel", channelName)) - } + successChannels = append(successChannels, sourcePChannel) + log.Info("start read channel", zap.String("channel", sourcePChannel)) + return nil + }) return err } +func GetVChannelByPChannel(pChannel string, vChannels []string) string { + for _, vChannel := range vChannels { + if strings.Contains(vChannel, pChannel) { + return vChannel + } + } + return "" +} + +func ForeachChannel(sourcePChannels, targetPChannels []string, f func(sourcePChannel, targetPChannel string) error) error { + if len(sourcePChannels) != len(targetPChannels) { + return errors.New("the lengths of source and target channels are not equal") + } + sources := make([]string, len(sourcePChannels)) + targets := make([]string, len(targetPChannels)) + copy(sources, sourcePChannels) + copy(targets, targetPChannels) + sort.Strings(sources) + sort.Strings(targets) + + for i, source := range sources { + if err := f(source, targets[i]); err != nil { + return err + } + } + return nil +} + func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error { var handlers []*replicateChannelHandler collectionID := collectionInfo.ID @@ -186,7 +208,11 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionIn } firstHandler := handlers[0] - partitionRecord := firstHandler.getCollectionTargetInfo(collectionID).PartitionInfo + targetInfo, err := firstHandler.getCollectionTargetInfo(collectionID) + if err != nil { + return err + } + partitionRecord := targetInfo.PartitionInfo if _, ok := partitionRecord[partitionInfo.PartitionName]; !ok { select { case r.apiEventChan <- &api.ReplicateAPIEvent{ @@ -203,7 +229,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionIn return ctx.Err() } } - log.Warn("start to add partition", zap.String("collection_name", collectionInfo.Schema.Name), zap.String("partition_name", partitionInfo.PartitionName), zap.Int("num", len(handlers))) + log.Info("start to add partition", zap.String("collection_name", collectionInfo.Schema.Name), zap.String("partition_name", partitionInfo.PartitionName), zap.Int("num", len(handlers))) barrier := NewBarrier(len(handlers), func(msgTs uint64, b *Barrier) { select { case <-b.CloseChan: @@ -229,7 +255,10 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, collectionIn r.replicatePartitions[collectionID][partitionInfo.PartitionID] = barrier.CloseChan r.partitionLock.Unlock() for _, handler := range handlers { - handler.AddPartitionInfo(collectionInfo, partitionInfo, barrier.BarrierChan) + err = handler.AddPartitionInfo(collectionInfo, partitionInfo, barrier.BarrierChan) + if err != nil { + return err + } } return nil } @@ -294,7 +323,9 @@ func (r *replicateChannelManager) startReadChannel(sourceInfo *model.SourceColle channelHandler, ok := r.channelHandlerMap[sourceInfo.PChannelName] if !ok { - channelHandler, err = newReplicateChannelHandler(sourceInfo, targetInfo, r.targetClient, &model.HandlerOpts{MessageBufferSize: r.messageBufferSize, Factory: r.factory}) + channelHandler, err = newReplicateChannelHandler(sourceInfo, targetInfo, + r.targetClient, r.apiEventChan, + &model.HandlerOpts{MessageBufferSize: r.messageBufferSize, Factory: r.factory}) if err != nil { log.Warn("fail to new replicate channel handler", zap.String("channel_name", sourceInfo.PChannelName), zap.Int64("collection_id", sourceInfo.CollectionID), zap.Error(err)) @@ -331,6 +362,7 @@ type replicateChannelHandler struct { collectionRecords map[int64]*model.TargetCollectionInfo collectionNames map[string]int64 msgPackChan chan *msgstream.MsgPack + apiEventChan chan *api.ReplicateAPIEvent } func (r *replicateChannelHandler) AddCollection(collectionID int64, targetInfo *model.TargetCollectionInfo) { @@ -350,18 +382,21 @@ func (r *replicateChannelHandler) RemoveCollection(collectionID int64) { } } -func (r *replicateChannelHandler) AddPartitionInfo(collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo, barrierChan chan<- uint64) { +func (r *replicateChannelHandler) AddPartitionInfo(collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo, barrierChan chan<- uint64) error { collectionID := collectionInfo.ID partitionID := partitionInfo.PartitionID collectionName := collectionInfo.Schema.Name partitionName := partitionInfo.PartitionName - targetInfo := r.getCollectionTargetInfo(collectionID) + targetInfo, err := r.getCollectionTargetInfo(collectionID) + if err != nil { + return err + } r.recordLock.Lock() defer r.recordLock.Unlock() if targetInfo.PartitionBarrierChan[partitionID] != nil { log.Info("the partition barrier chan is not nil", zap.Int64("collection_id", collectionID), zap.String("partition_name", partitionName), zap.Int64("partition_id", partitionID)) - return + return nil } targetInfo.PartitionBarrierChan[partitionID] = barrierChan go func() { @@ -373,6 +408,7 @@ func (r *replicateChannelHandler) AddPartitionInfo(collectionInfo *pb.Collection return nil }, util.GetRetryOptionsFor25s()...) }() + return nil } func (r *replicateChannelHandler) updateTargetPartitionInfo(collectionID int64, collectionName string, partitionName string) int64 { @@ -393,7 +429,11 @@ func (r *replicateChannelHandler) updateTargetPartitionInfo(collectionID int64, } func (r *replicateChannelHandler) RemovePartitionInfo(collectionID int64, name string, id int64) { - targetInfo := r.getCollectionTargetInfo(collectionID) + targetInfo, err := r.getCollectionTargetInfo(collectionID) + if err != nil { + log.Warn("fail to get collection target info", zap.Int64("collection_id", collectionID), zap.Error(err)) + return + } r.recordLock.Lock() defer r.recordLock.Unlock() if targetInfo.PartitionInfo[name] == id { @@ -408,10 +448,6 @@ func (r *replicateChannelHandler) IsEmpty() bool { return len(r.collectionRecords) == 0 } -func (r *replicateChannelHandler) Chan() chan<- *msgstream.MsgPack { - return r.msgPackChan -} - func (r *replicateChannelHandler) Close() { r.stream.Close() } @@ -429,7 +465,7 @@ func (r *replicateChannelHandler) startReadChannel() { }() } -func (r *replicateChannelHandler) getCollectionTargetInfo(collectionID int64) *model.TargetCollectionInfo { +func (r *replicateChannelHandler) getCollectionTargetInfo(collectionID int64) (*model.TargetCollectionInfo, error) { r.recordLock.RLock() targetInfo, ok := r.collectionRecords[collectionID] r.recordLock.RUnlock() @@ -450,9 +486,11 @@ func (r *replicateChannelHandler) getCollectionTargetInfo(collectionID int64) *m i++ } if !ok && i == 10 { - log.Panic("fail to find the collection info", zap.Int64("msg_collection_id", collectionID)) + err := errors.Newf("not found the collection [%d]", collectionID) + log.Warn("fail to find the collection info", zap.Error(err)) + return nil, err } - return targetInfo + return targetInfo, nil } func (r *replicateChannelHandler) containCollection(collectionName string) bool { @@ -461,7 +499,7 @@ func (r *replicateChannelHandler) containCollection(collectionName string) bool return r.collectionNames[collectionName] != 0 } -func (r *replicateChannelHandler) getPartitionID(sourceCollectionID int64, info *model.TargetCollectionInfo, name string) int64 { +func (r *replicateChannelHandler) getPartitionID(sourceCollectionID int64, info *model.TargetCollectionInfo, name string) (int64, error) { r.recordLock.RLock() id, ok := info.PartitionInfo[name] r.recordLock.RUnlock() @@ -474,9 +512,10 @@ func (r *replicateChannelHandler) getPartitionID(sourceCollectionID int64, info i++ } if !ok && i == 10 { - log.Panic("fail to find the partition id", zap.Int64("collection_id", info.CollectionID), zap.String("partition_name", name)) + log.Warn("fail to find the partition id", zap.Int64("source_collection", sourceCollectionID), zap.Any("target_collection", info.CollectionID), zap.String("partition_name", name)) + return 0, errors.Newf("not found the partition [%s]", name) } - return id + return id, nil } func (r *replicateChannelHandler) handlePack(pack *msgstream.MsgPack) *msgstream.MsgPack { @@ -506,16 +545,21 @@ func (r *replicateChannelHandler) handlePack(pack *msgstream.MsgPack) *msgstream if y, ok := msg.(interface{ GetCollectionID() int64 }); ok { sourceCollectionID := y.GetCollectionID() - info := r.getCollectionTargetInfo(sourceCollectionID) + info, err := r.getCollectionTargetInfo(sourceCollectionID) + if err != nil { + r.sendErrEvent(err) + log.Warn("fail to get collection info", zap.Int64("collection_id", sourceCollectionID), zap.Error(err)) + return nil + } switch realMsg := msg.(type) { case *msgstream.InsertMsg: realMsg.CollectionID = info.CollectionID - realMsg.PartitionID = r.getPartitionID(sourceCollectionID, info, realMsg.PartitionName) + realMsg.PartitionID, err = r.getPartitionID(sourceCollectionID, info, realMsg.PartitionName) realMsg.ShardName = info.VChannel case *msgstream.DeleteMsg: realMsg.CollectionID = info.CollectionID if realMsg.PartitionName != "" { - realMsg.PartitionID = r.getPartitionID(sourceCollectionID, info, realMsg.PartitionName) + realMsg.PartitionID, err = r.getPartitionID(sourceCollectionID, info, realMsg.PartitionName) } realMsg.ShardName = info.VChannel case *msgstream.DropCollectionMsg: @@ -525,13 +569,25 @@ func (r *replicateChannelHandler) handlePack(pack *msgstream.MsgPack) *msgstream case *msgstream.DropPartitionMsg: realMsg.CollectionID = info.CollectionID if realMsg.PartitionName == "" || info.PartitionBarrierChan[realMsg.PartitionID] == nil { - log.Panic("drop partition msg", zap.Any("msg", msg)) - } - info.PartitionBarrierChan[realMsg.PartitionID] <- msg.EndTs() - if realMsg.PartitionName != "" { - realMsg.PartitionID = r.getPartitionID(sourceCollectionID, info, realMsg.PartitionName) + err = errors.Newf("not found the partition info [%d]", realMsg.PartitionID) + log.Warn("invalid drop partition message", zap.Any("msg", msg)) + } else { + info.PartitionBarrierChan[realMsg.PartitionID] <- msg.EndTs() + if realMsg.PartitionName != "" { + realMsg.PartitionID, err = r.getPartitionID(sourceCollectionID, info, realMsg.PartitionName) + } } } + if err != nil { + r.sendErrEvent(err) + log.Warn("fail to get partition info", zap.Any("msg", msg), zap.Error(err)) + return nil + } + if pChannel != info.PChannel { + r.sendErrEvent(errors.New("there is a error about the replicate channel")) + log.Warn("pChannel not equal", zap.Any("msg", msg), zap.String("pChannel", pChannel), zap.String("info_pChannel", info.PChannel)) + return nil + } originPosition := msg.Position() msg.SetPosition(&msgpb.MsgPosition{ ChannelName: info.PChannel, @@ -539,9 +595,6 @@ func (r *replicateChannelHandler) handlePack(pack *msgstream.MsgPack) *msgstream MsgGroup: originPosition.GetMsgGroup(), Timestamp: originPosition.GetTimestamp(), }) - if pChannel != info.PChannel { - log.Panic("pChannel not equal", zap.String("pChannel", pChannel), zap.String("info_pChannel", info.PChannel)) - } newPack.Msgs = append(newPack.Msgs, msg) } else { log.Warn("not support msg type", zap.Any("msg", msg)) @@ -579,10 +632,18 @@ func (r *replicateChannelHandler) handlePack(pack *msgstream.MsgPack) *msgstream return newPack } +func (r *replicateChannelHandler) sendErrEvent(err error) { + r.apiEventChan <- &api.ReplicateAPIEvent{ + EventType: api.ReplicateError, + Error: err, + } +} + func newReplicateChannelHandler( sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo, targetClient api.TargetAPI, + apiEventChan chan *api.ReplicateAPIEvent, opts *model.HandlerOpts, ) (*replicateChannelHandler, error) { ctx := context.Background() @@ -614,6 +675,7 @@ func newReplicateChannelHandler( collectionRecords: make(map[int64]*model.TargetCollectionInfo), collectionNames: make(map[string]int64), msgPackChan: make(chan *msgstream.MsgPack, opts.MessageBufferSize), + apiEventChan: apiEventChan, } channelHandler.AddCollection(sourceInfo.CollectionID, targetInfo) channelHandler.startReadChannel() diff --git a/core/reader/replicate_channel_manager_test.go b/core/reader/replicate_channel_manager_test.go index 684cb029..2f6f4f14 100644 --- a/core/reader/replicate_channel_manager_test.go +++ b/core/reader/replicate_channel_manager_test.go @@ -40,6 +40,50 @@ func TestNewReplicateChannelManager(t *testing.T) { }) } +func TestChannelUtils(t *testing.T) { + t.Run("GetVChannelByPChannel", func(t *testing.T) { + assert.Equal(t, "p1_1", GetVChannelByPChannel("p1", []string{"p1_1", "p2_1", "p3_1"})) + assert.Equal(t, "", GetVChannelByPChannel("p1", []string{"p2_1", "p3_1"})) + }) + + t.Run("ForeachChannel", func(t *testing.T) { + { + err := ForeachChannel([]string{"p1"}, []string{}, nil) + assert.Error(t, err) + } + f := func(sourcePChannel, targetPChannel string) error { + switch sourcePChannel { + case "p1": + assert.Equal(t, "p1_1", targetPChannel) + case "p2": + assert.Equal(t, "p2_1", targetPChannel) + case "p3": + assert.Equal(t, "p3_1", targetPChannel) + default: + return errors.New("unexpected pchannel: " + sourcePChannel) + } + return nil + } + { + err := ForeachChannel([]string{"p1", "p2", "p3"}, []string{"p1_1", "p2_1", "p3_1"}, f) + assert.NoError(t, err) + } + { + err := ForeachChannel([]string{"p2", "p1", "p3"}, []string{"p1_1", "p2_1", "p3_1"}, f) + assert.NoError(t, err) + } + { + err := ForeachChannel([]string{"p3", "p1", "p2"}, []string{"p1_1", "p2_1", "p3_1"}, func(sourcePChannel, targetPChannel string) error { + if sourcePChannel == "p3" { + return errors.New("error") + } + return nil + }) + assert.Error(t, err) + } + }) +} + func TestStartReadCollection(t *testing.T) { factoryCreator := mocks.NewFactoryCreator(t) factory := msgstream.NewMockFactory(t) @@ -54,7 +98,7 @@ func TestStartReadCollection(t *testing.T) { }, factoryCreator, targetClient, 10) assert.NoError(t, err) - t.Run("success", func(t *testing.T) { + t.Run("context cancel", func(t *testing.T) { targetClient.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything).Return(nil, errors.New("error")).Once() ctx, cancelFunc := context.WithCancel(context.Background()) cancelFunc() @@ -83,35 +127,12 @@ func TestStartReadCollection(t *testing.T) { assert.Error(t, err) }) - t.Run("wrong channel info", func(t *testing.T) { - go func() { - event := <-realManager.apiEventChan - assert.Equal(t, api.ReplicateCreateCollection, event.EventType) - assert.Equal(t, "test", event.CollectionInfo.Schema.Name) - assert.True(t, event.ReplicateInfo.IsReplicate) - }() - targetClient.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() - targetClient.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything).Return(&model.CollectionInfo{ - PChannels: []string{"test_p"}, - VChannels: []string{"tes_v"}, - }, nil).Once() - realManager.retryOptions = []retry.Option{ - retry.Attempts(1), - } - err = manager.StartReadCollection(context.Background(), &pb.CollectionInfo{ - Schema: &schemapb.CollectionSchema{ - Name: "test", - }, - }, nil) - assert.Error(t, err) - }) - stream := msgstream.NewMockMsgStream(t) streamChan := make(chan *msgstream.MsgPack) - factory.EXPECT().NewTtMsgStream(mock.Anything).Return(stream, nil) - stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - stream.EXPECT().Chan().Return(streamChan) + factory.EXPECT().NewTtMsgStream(mock.Anything).Return(stream, nil).Maybe() + stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + stream.EXPECT().Chan().Return(streamChan).Maybe() stream.EXPECT().Close().Return().Maybe() t.Run("read channel", func(t *testing.T) { @@ -179,6 +200,7 @@ func TestStartReadCollection(t *testing.T) { Key: "collection_partition_p1", }, }, + PhysicalChannelNames: []string{"collection_partition_p1"}, }, nil) assert.NoError(t, err) } @@ -234,7 +256,7 @@ func TestReplicateChannelHandler(t *testing.T) { factory := msgstream.NewMockFactory(t) factory.EXPECT().NewTtMsgStream(mock.Anything).Return(nil, errors.New("mock error")) - _, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p"}, (*model.TargetCollectionInfo)(nil), api.TargetAPI(nil), &model.HandlerOpts{Factory: factory}) + _, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p"}, (*model.TargetCollectionInfo)(nil), api.TargetAPI(nil), nil, &model.HandlerOpts{Factory: factory}) assert.Error(t, err) }) @@ -246,7 +268,7 @@ func TestReplicateChannelHandler(t *testing.T) { { stream.EXPECT().Close().Return().Once() stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")).Once() - _, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p"}, (*model.TargetCollectionInfo)(nil), api.TargetAPI(nil), &model.HandlerOpts{Factory: factory}) + _, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p"}, (*model.TargetCollectionInfo)(nil), api.TargetAPI(nil), nil, &model.HandlerOpts{Factory: factory}) assert.Error(t, err) } @@ -254,7 +276,7 @@ func TestReplicateChannelHandler(t *testing.T) { stream.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() stream.EXPECT().Close().Return().Once() stream.EXPECT().Seek(mock.Anything, mock.Anything).Return(errors.New("mock error")).Once() - _, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p", SeekPosition: &msgstream.MsgPosition{ChannelName: "test_p", MsgID: []byte("test")}}, (*model.TargetCollectionInfo)(nil), api.TargetAPI(nil), &model.HandlerOpts{Factory: factory}) + _, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p", SeekPosition: &msgstream.MsgPosition{ChannelName: "test_p", MsgID: []byte("test")}}, (*model.TargetCollectionInfo)(nil), api.TargetAPI(nil), nil, &model.HandlerOpts{Factory: factory}) assert.Error(t, err) } }) @@ -272,7 +294,7 @@ func TestReplicateChannelHandler(t *testing.T) { stream.EXPECT().Close().Return().Once() stream.EXPECT().Seek(mock.Anything, mock.Anything).Return(nil).Once() stream.EXPECT().Chan().Return(streamChan).Once() - handler, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p", SeekPosition: &msgstream.MsgPosition{ChannelName: "test_p", MsgID: []byte("test")}}, &model.TargetCollectionInfo{PChannel: "test_p"}, api.TargetAPI(nil), &model.HandlerOpts{Factory: factory}) + handler, err := newReplicateChannelHandler(&model.SourceCollectionInfo{PChannelName: "test_p", SeekPosition: &msgstream.MsgPosition{ChannelName: "test_p", MsgID: []byte("test")}}, &model.TargetCollectionInfo{PChannel: "test_p"}, api.TargetAPI(nil), nil, &model.HandlerOpts{Factory: factory}) assert.NoError(t, err) time.Sleep(100 * time.Microsecond) handler.Close() @@ -300,7 +322,7 @@ func TestReplicateChannelHandler(t *testing.T) { }, &model.TargetCollectionInfo{ PChannel: "test_p", CollectionName: "foo", - }, targetClient, &model.HandlerOpts{Factory: factory}) + }, targetClient, nil, &model.HandlerOpts{Factory: factory}) assert.NoError(t, err) time.Sleep(100 * time.Millisecond) assert.True(t, handler.containCollection("foo")) @@ -335,7 +357,7 @@ func TestReplicateChannelHandler(t *testing.T) { apiEventChan := make(chan *api.ReplicateAPIEvent) handler, err := func() (*replicateChannelHandler, error) { var _ chan<- *api.ReplicateAPIEvent = apiEventChan - return newReplicateChannelHandler(&model.SourceCollectionInfo{CollectionID: 1, PChannelName: "test_p", SeekPosition: &msgstream.MsgPosition{ChannelName: "test_p", MsgID: []byte("test")}}, &model.TargetCollectionInfo{CollectionID: 100, CollectionName: "test", PChannel: "test_p"}, api.TargetAPI(targetClient), &model.HandlerOpts{Factory: factory}) + return newReplicateChannelHandler(&model.SourceCollectionInfo{CollectionID: 1, PChannelName: "test_p", SeekPosition: &msgstream.MsgPosition{ChannelName: "test_p", MsgID: []byte("test")}}, &model.TargetCollectionInfo{CollectionID: 100, CollectionName: "test", PChannel: "test_p"}, api.TargetAPI(targetClient), nil, &model.HandlerOpts{Factory: factory}) }() assert.NoError(t, err) time.Sleep(100 * time.Millisecond) @@ -350,7 +372,7 @@ func TestReplicateChannelHandler(t *testing.T) { }, }) }() - handler.AddPartitionInfo(&pb.CollectionInfo{ + _ = handler.AddPartitionInfo(&pb.CollectionInfo{ ID: 2, Schema: &schemapb.CollectionSchema{ Name: "test2", @@ -363,7 +385,7 @@ func TestReplicateChannelHandler(t *testing.T) { handler.RemovePartitionInfo(2, "p2", 10002) assert.False(t, handler.IsEmpty()) - assert.NotNil(t, handler.Chan()) + assert.NotNil(t, handler.msgPackChan) // test updateTargetPartitionInfo targetClient.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything).Return(nil, errors.New("mock error")).Once() @@ -417,7 +439,7 @@ func TestReplicateChannelHandler(t *testing.T) { PartitionBarrierChan: map[int64]chan<- uint64{ 1021: partitionBarrierChan, }, - }, targetClient, &model.HandlerOpts{Factory: factory}) + }, targetClient, nil, &model.HandlerOpts{Factory: factory}) assert.NoError(t, err) done := make(chan struct{}) diff --git a/server/cdc_impl.go b/server/cdc_impl.go index 6030125b..cfbdeb94 100644 --- a/server/cdc_impl.go +++ b/server/cdc_impl.go @@ -157,7 +157,8 @@ func (e *MetaCDC) ReloadTask() { e.cdcTasks.Unlock() if err := e.startInternal(taskInfo, taskInfo.State == meta.TaskStateRunning); err != nil { - log.Panic("fail to start the task", zap.Any("task_info", taskInfo), zap.Error(err)) + log.Warn("fail to start the task", zap.Any("task_info", taskInfo), zap.Error(err)) + _ = e.pauseTaskWithReason(taskInfo.TaskID, "fail to start task, err: "+err.Error(), []meta.TaskState{}) } } } @@ -375,11 +376,11 @@ func (e *MetaCDC) getUUID() string { } func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) error { + taskLog := log.With(zap.String("task_id", info.TaskID)) milvusConnectParam := info.MilvusConnectParam milvusAddress := fmt.Sprintf("%s:%d", milvusConnectParam.Host, milvusConnectParam.Port) e.replicateEntityMap.RLock() replicateEntity, ok := e.replicateEntityMap.data[milvusAddress] - log.Info("ok", zap.Any("ok", ok)) e.replicateEntityMap.RUnlock() newReplicateEntity := func() (*ReplicateEntity, error) { @@ -393,7 +394,7 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err }) cancelFunc() if err != nil { - log.Warn("fail to new target", zap.String("address", milvusAddress), zap.Error(err)) + taskLog.Warn("fail to new target", zap.String("address", milvusAddress), zap.Error(err)) return nil, servererror.NewClientError("fail to connect target milvus server") } // TODO improve it @@ -403,7 +404,7 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err Kafka: e.config.SourceConfig.Kafka, }, e.mqFactoryCreator, milvusClient, bufferSize) if err != nil { - log.Warn("fail to create replicate channel manager", zap.Error(err)) + taskLog.Warn("fail to create replicate channel manager", zap.Error(err)) return nil, servererror.NewClientError("fail to create replicate channel manager") } targetConfig := milvusConnectParam @@ -414,14 +415,14 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err cdcwriter.IgnorePartitionOption(targetConfig.IgnorePartition), cdcwriter.ConnectTimeoutOption(targetConfig.ConnectTimeout)) if err != nil { - log.Warn("fail to new the data handler", zap.Error(err)) + taskLog.Warn("fail to new the data handler", zap.Error(err)) return nil, servererror.NewClientError("fail to new the data handler, task_id: ") } writerObj := cdcwriter.NewChannelWriter(dataHandler, bufferSize) sourceConfig := e.config.SourceConfig metaOp, err := cdcreader.NewEtcdOp(sourceConfig.EtcdAddress, sourceConfig.EtcdRootPath, sourceConfig.EtcdMetaSubPath, sourceConfig.DefaultPartitionName) if err != nil { - log.Warn("fail to new the meta op", zap.Error(err)) + taskLog.Warn("fail to new the meta op", zap.Error(err)) return nil, servererror.NewClientError("fail to new the meta op") } @@ -440,13 +441,23 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err for { replicateAPIEvent, ok := <-entity.channelManager.GetEventChan() if !ok { - log.Warn("the replicate api event channel has closed") + taskLog.Warn("the replicate api event channel has closed") + return + } + if !e.isRunningTask(info.TaskID) { + taskLog.Warn("not running task", zap.Any("event", replicateAPIEvent)) + return + } + if replicateAPIEvent.EventType == api.ReplicateError { + taskLog.Warn("receive the error event", zap.Any("event", replicateAPIEvent)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to read the replicate event", []meta.TaskState{}) return } err := entity.writerObj.HandleReplicateAPIEvent(context.Background(), replicateAPIEvent) if err != nil { - // TODO - log.Panic("fail to handle the replicate api event", zap.Error(err)) + taskLog.Warn("fail to handle replicate event", zap.Any("event", replicateAPIEvent), zap.Error(err)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to handle the replicate event, err: "+err.Error(), []meta.TaskState{}) + return } } }() @@ -455,24 +466,41 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err for { // TODO how to close them channelName, ok := <-entity.channelManager.GetChannelChan() - log.Info("start to replicate channel", zap.String("channel", channelName)) + taskLog.Info("start to replicate channel", zap.String("channel", channelName)) if !ok { - log.Warn("the channel name channel has closed") + taskLog.Warn("the channel name channel has closed") + return + } + if !e.isRunningTask(info.TaskID) { + taskLog.Warn("not running task") return } go func(c string) { for { - msgPack, ok := <-entity.channelManager.GetMsgChan(c) + msgChan := entity.channelManager.GetMsgChan(c) + if msgChan == nil { + log.Warn("not found the message channel", zap.String("channel", c)) + return + } + msgPack, ok := <-msgChan if !ok { - log.Warn("the data channel has closed") + taskLog.Warn("the data channel has closed") + return + } + if !e.isRunningTask(info.TaskID) { + taskLog.Warn("not running task", zap.Any("pack", msgPack)) + return + } + if msgPack == nil { + log.Warn("the message pack is nil, the task may be stopping") return } pChannel := msgPack.EndPositions[0].GetChannelName() position, targetPosition, err := entity.writerObj.HandleReplicateMessage(context.Background(), pChannel, msgPack) if err != nil { - // TODO - log.Panic("fail to handle the replicate message", zap.Error(err)) - continue + taskLog.Warn("fail to handle the replicate message", zap.Any("pack", msgPack), zap.Error(err)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to handle replicate message, err:"+err.Error(), []meta.TaskState{}) + return } msgTime, _ := tsoutil.ParseHybridTs(msgPack.EndTs) metaPosition := &meta.PositionInfo{ @@ -494,8 +522,13 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err }, } if position != nil { - writeCallback.UpdateTaskCollectionPosition(TmpCollectionID, TmpCollectionName, c, + err = writeCallback.UpdateTaskCollectionPosition(TmpCollectionID, TmpCollectionName, c, metaPosition, metaOpPosition, metaTargetPosition) + if err != nil { + log.Warn("fail to update the collection position", zap.Any("pack", msgPack), zap.Error(err)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to update task position, err:"+err.Error(), []meta.TaskState{}) + return + } } } }(channelName) @@ -515,16 +548,16 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err ctx := context.Background() taskPositions, err := e.metaStoreFactory.GetTaskCollectionPositionMetaStore(ctx).Get(ctx, &meta.TaskCollectionPosition{TaskID: info.TaskID}, nil) if err != nil { - log.Warn("fail to get the task collection position", zap.Error(err)) + taskLog.Warn("fail to get the task collection position", zap.Error(err)) return servererror.NewServerError(errors.WithMessage(err, "fail to get the task collection position")) } if len(taskPositions) > 1 { - log.Warn("the task collection position is invalid", zap.Any("task_id", info.TaskID)) + taskLog.Warn("the task collection position is invalid", zap.Any("task_id", info.TaskID)) return servererror.NewServerError(errors.New("the task collection position is invalid")) } channelSeekPosition := make(map[string]*msgpb.MsgPosition) if len(taskPositions) == 1 { - log.Info("task seek position", zap.Any("position", taskPositions[0].Positions)) + taskLog.Info("task seek position", zap.Any("position", taskPositions[0].Positions)) for _, p := range taskPositions[0].Positions { dataPair := p.DataPair channelSeekPosition[dataPair.GetKey()] = &msgpb.MsgPosition{ @@ -535,18 +568,30 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err } collectionReader, err := cdcreader.NewCollectionReader(info.TaskID, replicateEntity.channelManager, replicateEntity.metaOp, channelSeekPosition, GetShouldReadFunc(info)) if err != nil { - log.Warn("fail to new the collection reader", zap.Error(err)) + taskLog.Warn("fail to new the collection reader", zap.Error(err)) return servererror.NewServerError(errors.WithMessage(err, "fail to new the collection reader")) } + go func() { + err := <-collectionReader.ErrorChan() + if err == nil { + return + } + log.Warn("fail to read the message", zap.Error(err)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to read the message, err:"+err.Error(), []meta.TaskState{}) + }() channelReader, err := cdcreader.NewChannelReader(info.RPCRequestChannelInfo.Name, info.RPCRequestChannelInfo.Position, config.MQConfig{ Pulsar: e.config.SourceConfig.Pulsar, Kafka: e.config.SourceConfig.Kafka, }, func(pack *msgstream.MsgPack) bool { + if !e.isRunningTask(info.TaskID) { + taskLog.Warn("not running task", zap.Any("pack", pack)) + return false + } positionBytes, err := replicateEntity.writerObj.HandleOpMessagePack(ctx, pack) if err != nil { - // TODO - log.Panic("fail to handle the op message pack", zap.Error(err)) + taskLog.Warn("fail to handle the op message pack", zap.Error(err)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to handle the op message pack, err:"+err.Error(), []meta.TaskState{}) return false } msgTime, _ := tsoutil.ParseHybridTs(pack.EndTs) @@ -559,12 +604,17 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err }, } writeCallback := NewWriteCallback(e.metaStoreFactory, e.rootPath, info.TaskID) - writeCallback.UpdateTaskCollectionPosition(TmpCollectionID, TmpCollectionName, channelName, + err = writeCallback.UpdateTaskCollectionPosition(TmpCollectionID, TmpCollectionName, channelName, metaPosition, metaPosition, nil) + if err != nil { + log.Warn("fail to update the collection position", zap.Any("pack", pack), zap.Error(err)) + _ = e.pauseTaskWithReason(info.TaskID, "fail to update task position, err:"+err.Error(), []meta.TaskState{}) + return false + } return true }, e.mqFactoryCreator) if err != nil { - log.Warn("fail to new the channel reader", zap.Error(err)) + taskLog.Warn("fail to new the channel reader", zap.Error(err)) return servererror.NewServerError(errors.WithMessage(err, "fail to new the channel reader")) } readCtx, cancelReadFunc := context.WithCancel(context.Background()) @@ -578,21 +628,63 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err e.replicateEntityMap.Unlock() if !ignoreUpdateState { - err = store.UpdateTaskState( - e.metaStoreFactory.GetTaskInfoMetaStore(ctx), - info.TaskID, - meta.TaskStateRunning, - []meta.TaskState{meta.TaskStateInitial, meta.TaskStatePaused}) + err = store.UpdateTaskState(e.metaStoreFactory.GetTaskInfoMetaStore(ctx), info.TaskID, meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial, meta.TaskStatePaused}, "") if err != nil { - log.Warn("fail to update the task meta", zap.Error(err)) + taskLog.Warn("fail to update the task meta", zap.Error(err)) return servererror.NewServerError(errors.WithMessage(err, "fail to update the task meta, task_id: "+info.TaskID)) } } + e.cdcTasks.Lock() + info.State = meta.TaskStateRunning + info.Reason = "" + e.cdcTasks.Unlock() + collectionReader.StartRead(readCtx) channelReader.StartRead(readCtx) return nil } +func (e *MetaCDC) isRunningTask(taskID string) bool { + e.cdcTasks.RLock() + defer e.cdcTasks.RUnlock() + task, ok := e.cdcTasks.data[taskID] + if !ok { + return false + } + return task.State == meta.TaskStateRunning +} + +func (e *MetaCDC) pauseTaskWithReason(taskID, reason string, currentStates []meta.TaskState) error { + err := store.UpdateTaskState( + e.metaStoreFactory.GetTaskInfoMetaStore(context.Background()), + taskID, + meta.TaskStatePaused, + currentStates, + reason) + if err != nil { + log.Warn("fail to update task reason", zap.String("task_id", taskID), zap.String("reason", reason)) + return err + } + e.cdcTasks.Lock() + cdcTask := e.cdcTasks.data[taskID] + if cdcTask == nil { + e.cdcTasks.Unlock() + return nil + } + cdcTask.State = meta.TaskStatePaused + cdcTask.Reason = reason + e.cdcTasks.Unlock() + + milvusAddress := GetMilvusAddress(cdcTask.MilvusConnectParam) + e.replicateEntityMap.Lock() + if replicateEntity, ok := e.replicateEntityMap.data[milvusAddress]; ok { + replicateEntity.quitFunc() + } + delete(e.replicateEntityMap.data, milvusAddress) + e.replicateEntityMap.Unlock() + return nil +} + func (e *MetaCDC) Delete(req *request.DeleteRequest) (*request.DeleteResponse, error) { e.cdcTasks.RLock() _, ok := e.cdcTasks.data[req.TaskID] @@ -633,27 +725,16 @@ func (e *MetaCDC) Delete(req *request.DeleteRequest) (*request.DeleteResponse, e func (e *MetaCDC) Pause(req *request.PauseRequest) (*request.PauseResponse, error) { e.cdcTasks.RLock() - cdcTask, ok := e.cdcTasks.data[req.TaskID] + _, ok := e.cdcTasks.data[req.TaskID] e.cdcTasks.RUnlock() if !ok { return nil, servererror.NewClientError("not found the task, task_id: " + req.TaskID) } - err := store.UpdateTaskState( - e.metaStoreFactory.GetTaskInfoMetaStore(context.Background()), - req.TaskID, - meta.TaskStatePaused, - []meta.TaskState{meta.TaskStateRunning}) + err := e.pauseTaskWithReason(req.TaskID, "manually pause through http interface", []meta.TaskState{meta.TaskStateRunning}) if err != nil { - return nil, servererror.NewServerError(errors.WithMessage(err, "fail to update the task meta, task_id: "+req.TaskID)) + return nil, servererror.NewServerError(errors.WithMessage(err, "fail to update the task state, task_id: "+req.TaskID)) } - milvusAddress := fmt.Sprintf("%s:%d", cdcTask.MilvusConnectParam.Host, cdcTask.MilvusConnectParam.Port) - e.replicateEntityMap.Lock() - if replicateEntity, ok := e.replicateEntityMap.data[milvusAddress]; ok { - replicateEntity.quitFunc() - } - delete(e.replicateEntityMap.data, milvusAddress) - e.replicateEntityMap.Unlock() return &request.PauseResponse{}, err } @@ -732,6 +813,10 @@ func (e *MetaCDC) List(req *request.ListRequest) (*request.ListResponse, error) }, nil } +func GetMilvusAddress(param model.MilvusConnectParam) string { + return fmt.Sprintf("%s:%d", param.Host, param.Port) +} + func GetShouldReadFunc(taskInfo *meta.TaskInfo) cdcreader.ShouldReadFunc { isAll := taskInfo.CollectionInfos[0].Name == cdcreader.AllCollection return func(collectionInfo *pb.CollectionInfo) bool { diff --git a/server/cdc_impl_test.go b/server/cdc_impl_test.go index 12d399b7..a6364d51 100644 --- a/server/cdc_impl_test.go +++ b/server/cdc_impl_test.go @@ -38,6 +38,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/config" coremocks "github.com/zilliztech/milvus-cdc/core/mocks" "github.com/zilliztech/milvus-cdc/core/pb" + "github.com/zilliztech/milvus-cdc/core/util" "github.com/zilliztech/milvus-cdc/server/mocks" "github.com/zilliztech/milvus-cdc/server/model" "github.com/zilliztech/milvus-cdc/server/model/meta" @@ -203,34 +204,35 @@ func TestReload(t *testing.T) { EnableReverse: false, } metaCDC.metaStoreFactory = factory - assert.Panics(t, func() { - factory.EXPECT().GetTaskInfoMetaStore(mock.Anything).Return(store).Once() - factory.EXPECT().GetTaskCollectionPositionMetaStore(mock.Anything).Return(positionStore).Once() - store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{ - { - TaskID: "1234", - State: meta.TaskStateRunning, - MilvusConnectParam: model.MilvusConnectParam{ - Host: "127.0.0.1", - Port: 19530, - }, - CollectionInfos: []model.CollectionInfo{ - { - Name: "foo", - }, + factory.EXPECT().GetTaskInfoMetaStore(mock.Anything).Return(store) + factory.EXPECT().GetTaskCollectionPositionMetaStore(mock.Anything).Return(positionStore).Once() + store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{ + { + TaskID: "1234", + State: meta.TaskStateRunning, + MilvusConnectParam: model.MilvusConnectParam{ + Host: "127.0.0.1", + Port: 19530, + }, + CollectionInfos: []model.CollectionInfo{ + { + Name: "foo", }, }, - }, nil).Once() - positionStore.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("test")).Once() + }, + }, nil).Once() + positionStore.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("test")).Once() - metaCDC.replicateEntityMap.Lock() - metaCDC.replicateEntityMap.data = map[string]*ReplicateEntity{ - "127.0.0.1:19530": {}, - } - metaCDC.replicateEntityMap.Unlock() + metaCDC.replicateEntityMap.Lock() + metaCDC.replicateEntityMap.data = map[string]*ReplicateEntity{ + "127.0.0.1:19530": {}, + } + metaCDC.replicateEntityMap.Unlock() - metaCDC.ReloadTask() - }) + // get error when pause task + store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("test")).Once() + + metaCDC.ReloadTask() }) } @@ -1061,3 +1063,79 @@ func TestDelete(t *testing.T) { assert.NoError(t, err) }) } + +func TestIsRunningTask(t *testing.T) { + m := &MetaCDC{} + initMetaCDCMap(m) + + assert.False(t, m.isRunningTask("task1")) + + m.cdcTasks.Lock() + m.cdcTasks.data["task2"] = &meta.TaskInfo{ + State: meta.TaskStateRunning, + } + m.cdcTasks.data["task3"] = &meta.TaskInfo{ + State: meta.TaskStatePaused, + } + m.cdcTasks.Unlock() + assert.True(t, m.isRunningTask("task2")) + assert.False(t, m.isRunningTask("task3")) +} + +func TestPauseTask(t *testing.T) { + m := &MetaCDC{} + factory := mocks.NewMetaStoreFactory(t) + store := mocks.NewMetaStore[*meta.TaskInfo](t) + + initMetaCDCMap(m) + m.metaStoreFactory = factory + factory.EXPECT().GetTaskInfoMetaStore(mock.Anything).Return(store) + + t.Run("fail to update state", func(t *testing.T) { + store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once() + err := m.pauseTaskWithReason("task1", "foo", []meta.TaskState{}) + assert.Error(t, err) + }) + + t.Run("not found task", func(t *testing.T) { + store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{ + { + TaskID: "task1", + State: meta.TaskStateRunning, + }, + }, nil).Once() + store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + err := m.pauseTaskWithReason("task1", "foo", []meta.TaskState{}) + assert.NoError(t, err) + }) + + t.Run("ok", func(t *testing.T) { + store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{ + { + TaskID: "task1", + State: meta.TaskStateRunning, + }, + }, nil).Once() + store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + m.cdcTasks.Lock() + m.cdcTasks.data["task1"] = &meta.TaskInfo{ + MilvusConnectParam: model.MilvusConnectParam{ + Host: "127.0.0.1", + Port: 19530, + }, + } + m.cdcTasks.Unlock() + var isQuit util.Value[bool] + isQuit.Store(false) + m.replicateEntityMap.Lock() + m.replicateEntityMap.data["127.0.0.1:19530"] = &ReplicateEntity{ + quitFunc: func() { + isQuit.Store(true) + }, + } + m.replicateEntityMap.Unlock() + err := m.pauseTaskWithReason("task1", "foo", []meta.TaskState{}) + assert.NoError(t, err) + assert.True(t, isQuit.Load()) + }) +} diff --git a/server/configs/cdc.yaml b/server/configs/cdc.yaml index bfc71505..10021311 100644 --- a/server/configs/cdc.yaml +++ b/server/configs/cdc.yaml @@ -4,7 +4,7 @@ metaStoreConfig: storeType: etcd etcdEndpoints: - localhost:2379 - mysqlSourceUrl: root:root@tcp(127.0.0.1:3306)/milvus-cdc?charset=utf8 + mysqlSourceUrl: root:root@tcp(127.0.0.1:3306)/milvuscdc?charset=utf8 rootPath: cdc sourceConfig: etcdAddress: diff --git a/server/model/meta/task.go b/server/model/meta/task.go index a2b206d3..a9695bec 100644 --- a/server/model/meta/task.go +++ b/server/model/meta/task.go @@ -68,7 +68,7 @@ type TaskInfo struct { RPCRequestChannelInfo model.ChannelInfo ExcludeCollections []string // it's used for the `*` collection name State TaskState - FailedReason string + Reason string } func (t *TaskInfo) CollectionNames() []string { diff --git a/server/model/request/base.go b/server/model/request/base.go index 39a95041..ae1b769a 100644 --- a/server/model/request/base.go +++ b/server/model/request/base.go @@ -31,20 +31,24 @@ const ( List = "list" ) -//go:generate easytags $GOFILE json,mapstructure - type CDCRequest struct { RequestType string `json:"request_type" mapstructure:"request_type"` RequestData map[string]any `json:"request_data" mapstructure:"request_data"` } +type CDCResponse struct { + Code int `json:"code" mapstructure:"code"` + Message string `json:"message" mapstructure:"message"` + Data map[string]any `json:"data" mapstructure:"data"` +} + // Task some info can be showed about the task type Task struct { TaskID string `json:"task_id" mapstructure:"task_id"` MilvusConnectParam model.MilvusConnectParam `json:"milvus_connect_param" mapstructure:"milvus_connect_param"` CollectionInfos []model.CollectionInfo `json:"collection_infos" mapstructure:"collection_infos"` State string `json:"state" mapstructure:"state"` - LastFailReason string `json:"reason,omitempty" mapstructure:"reason"` + LastPauseReason string `json:"reason" mapstructure:"reason"` } func GetTask(taskInfo *meta.TaskInfo) Task { @@ -55,6 +59,6 @@ func GetTask(taskInfo *meta.TaskInfo) Task { MilvusConnectParam: taskInfo.MilvusConnectParam, CollectionInfos: taskInfo.CollectionInfos, State: taskInfo.State.String(), - LastFailReason: taskInfo.FailedReason, + LastPauseReason: taskInfo.Reason, } } diff --git a/server/server.go b/server/server.go index 23e629d1..4ad9c527 100644 --- a/server/server.go +++ b/server/server.go @@ -55,6 +55,7 @@ func (c *CDCServer) Run(config *CDCServerConfig) { func (c *CDCServer) getCDCHandler() http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { startTime := time.Now() + writer.Header().Set("Content-Type", "application/json") if request.Method != http.MethodPost { c.handleError(writer, "only support the POST method", http.StatusMethodNotAllowed, zap.String("method", request.Method)) @@ -79,8 +80,18 @@ func (c *CDCServer) getCDCHandler() http.Handler { response := c.handleRequest(cdcRequest, writer) if response != nil { - writer.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(writer).Encode(response) + var m map[string]interface{} + err = mapstructure.Decode(response, &m) + if err != nil { + log.Warn("fail to decode the response", zap.Any("resp", response), zap.Error(err)) + c.handleError(writer, err.Error(), http.StatusInternalServerError) + return + } + realResp := &modelrequest.CDCResponse{ + Code: 200, + Data: m, + } + _ = json.NewEncoder(writer).Encode(realResp) metrics.TaskRequestCountVec.WithLabelValues(cdcRequest.RequestType, metrics.SuccessStatusLabel).Inc() metrics.TaskRequestLatencyVec.WithLabelValues(cdcRequest.RequestType).Observe(float64(time.Since(startTime).Milliseconds())) } @@ -89,7 +100,11 @@ func (c *CDCServer) getCDCHandler() http.Handler { func (c *CDCServer) handleError(w http.ResponseWriter, error string, code int, fields ...zap.Field) { log.Warn(error, fields...) - http.Error(w, error, code) + errResp := &modelrequest.CDCResponse{ + Code: code, + Message: error, + } + _ = json.NewEncoder(w).Encode(errResp) } func (c *CDCServer) handleRequest(cdcRequest *modelrequest.CDCRequest, writer http.ResponseWriter) any { diff --git a/server/server_test.go b/server/server_test.go index aad3d6ee..5199c25a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -24,9 +24,13 @@ import ( "testing" "github.com/cockroachdb/errors" + "github.com/milvus-io/milvus/pkg/log" + "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/assert" + "go.uber.org/zap" cdcerror "github.com/zilliztech/milvus-cdc/server/error" + "github.com/zilliztech/milvus-cdc/server/model" "github.com/zilliztech/milvus-cdc/server/model/request" ) @@ -209,3 +213,46 @@ func TestCDCHandler(t *testing.T) { assert.Contains(t, string(responseWriter.resp), taskID) }) } + +func TestDecodeStruct(t *testing.T) { + t.Run("err", func(t *testing.T) { + buf := bytes.NewBufferString("") + errResp := &request.CDCResponse{ + Code: 500, + Message: "error msg", + } + _ = json.NewEncoder(buf).Encode(errResp) + log.Warn("err", zap.Any("resp", buf.String())) + }) + + t.Run("success", func(t *testing.T) { + buf := bytes.NewBufferString("") + var m map[string]interface{} + response := &request.ListResponse{ + Tasks: []request.Task{ + { + TaskID: "123", + MilvusConnectParam: model.MilvusConnectParam{ + Host: "localhost", + Port: 19530, + }, + CollectionInfos: []model.CollectionInfo{ + { + Name: "foo", + }, + }, + State: "Running", + LastPauseReason: "receive the pause request", + }, + }, + } + + _ = mapstructure.Decode(response, &m) + realResp := &request.CDCResponse{ + Code: 200, + Data: m, + } + _ = json.NewEncoder(buf).Encode(realResp) + log.Warn("err", zap.Any("resp", buf.String())) + }) +} diff --git a/server/store/meta_op.go b/server/store/meta_op.go index 88f9d0ad..a7c9b55e 100644 --- a/server/store/meta_op.go +++ b/server/store/meta_op.go @@ -58,7 +58,9 @@ func GetAllTaskInfo(taskInfoStore api.MetaStore[*meta.TaskInfo]) ([]*meta.TaskIn return taskInfos, nil } -func UpdateTaskState(taskInfoStore api.MetaStore[*meta.TaskInfo], taskID string, newState meta.TaskState, oldStates []meta.TaskState) error { +func UpdateTaskState(taskInfoStore api.MetaStore[*meta.TaskInfo], taskID string, + newState meta.TaskState, oldStates []meta.TaskState, reason string, +) error { ctx := context.Background() infos, err := taskInfoStore.Get(ctx, &meta.TaskInfo{TaskID: taskID}, nil) if err != nil { @@ -70,7 +72,7 @@ func UpdateTaskState(taskInfoStore api.MetaStore[*meta.TaskInfo], taskID string, return errors.Errorf("not found the task info with task id %s", taskID) } info := infos[0] - if !lo.Contains(oldStates, info.State) { + if len(oldStates) != 0 && !lo.Contains(oldStates, info.State) { oldStateStrs := lo.Map[meta.TaskState, string](oldStates, func(taskState meta.TaskState, i int) string { return taskState.String() }) @@ -79,6 +81,9 @@ func UpdateTaskState(taskInfoStore api.MetaStore[*meta.TaskInfo], taskID string, } oldState := info.State info.State = newState + if newState == meta.TaskStatePaused { + info.Reason = reason + } err = taskInfoStore.Put(ctx, info, nil) if err != nil { log.Warn("fail to put the task info to etcd", zap.String("task_id", taskID), zap.Error(err)) @@ -88,27 +93,6 @@ func UpdateTaskState(taskInfoStore api.MetaStore[*meta.TaskInfo], taskID string, return nil } -func UpdateTaskFailedReason(taskInfoStore api.MetaStore[*meta.TaskInfo], taskID string, reason string) error { - ctx := context.Background() - infos, err := taskInfoStore.Get(ctx, &meta.TaskInfo{TaskID: taskID}, nil) - if err != nil { - log.Warn("fail to get the task info", zap.String("task_id", taskID), zap.Error(err)) - return err - } - if len(infos) == 0 { - log.Warn("not found the task info", zap.String("task_id", taskID)) - return servererror.NewNotFoundError(taskID) - } - info := infos[0] - info.FailedReason = reason - err = taskInfoStore.Put(ctx, info, nil) - if err != nil { - log.Warn("fail to put the task info", zap.String("task_id", taskID), zap.Error(err)) - return err - } - return nil -} - func UpdateTaskCollectionPosition(taskPositionStore api.MetaStore[*meta.TaskCollectionPosition], taskID string, collectionID int64, collectionName string, pChannelName string, position, opPosition, targetPosition *meta.PositionInfo) error { ctx := context.Background() positions, err := taskPositionStore.Get(ctx, &meta.TaskCollectionPosition{TaskID: taskID, CollectionID: collectionID}, nil) diff --git a/server/store/meta_op_test.go b/server/store/meta_op_test.go index 43c622c7..7eb16a55 100644 --- a/server/store/meta_op_test.go +++ b/server/store/meta_op_test.go @@ -60,63 +60,40 @@ func TestUpdateState(t *testing.T) { t.Run("fail to get task info", func(t *testing.T) { store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once() - err := UpdateTaskState(store, "1234", meta.TaskStateInitial, []meta.TaskState{}) + err := UpdateTaskState(store, "1234", meta.TaskStateInitial, []meta.TaskState{}, "") assert.Error(t, err) }) t.Run("empty task info", func(t *testing.T) { store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{}, nil).Once() - err := UpdateTaskState(store, "1234", meta.TaskStateInitial, []meta.TaskState{}) + err := UpdateTaskState(store, "1234", meta.TaskStateInitial, []meta.TaskState{}, "") assert.Error(t, err) }) t.Run("unexpect state", func(t *testing.T) { store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{{TaskID: "1234", State: meta.TaskStateRunning}}, nil).Once() - err := UpdateTaskState(store, "1234", meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial}) + err := UpdateTaskState(store, "1234", meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial}, "") assert.Error(t, err) }) t.Run("fail to put task info", func(t *testing.T) { store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{{TaskID: "1234", State: meta.TaskStateInitial}}, nil).Once() store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("fail")).Once() - err := UpdateTaskState(store, "1234", meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial}) + err := UpdateTaskState(store, "1234", meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial}, "") assert.Error(t, err) }) t.Run("success", func(t *testing.T) { store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{{TaskID: "1234", State: meta.TaskStateInitial}}, nil).Once() store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - err := UpdateTaskState(store, "1234", meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial}) + err := UpdateTaskState(store, "1234", meta.TaskStateRunning, []meta.TaskState{meta.TaskStateInitial}, "") assert.NoError(t, err) }) -} - -func TestUpdateFailReason(t *testing.T) { - store := mocks.NewMetaStore[*meta.TaskInfo](t) - - t.Run("fail to get task info", func(t *testing.T) { - store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("fail")).Once() - err := UpdateTaskFailedReason(store, "1234", "foo") - assert.Error(t, err) - }) - - t.Run("empty task info", func(t *testing.T) { - store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{}, nil).Once() - err := UpdateTaskFailedReason(store, "1234", "foo") - assert.Error(t, err) - }) - - t.Run("fail to put task info", func(t *testing.T) { - store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{{TaskID: "1234", State: meta.TaskStateInitial}}, nil).Once() - store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(errors.New("fail")).Once() - err := UpdateTaskFailedReason(store, "1234", "foo") - assert.Error(t, err) - }) - t.Run("success", func(t *testing.T) { + t.Run("success to pause", func(t *testing.T) { store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskInfo{{TaskID: "1234", State: meta.TaskStateInitial}}, nil).Once() store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - err := UpdateTaskFailedReason(store, "1234", "foo") + err := UpdateTaskState(store, "1234", meta.TaskStatePaused, []meta.TaskState{}, "pause test") assert.NoError(t, err) }) } diff --git a/server/writer_callback.go b/server/writer_callback.go index 9328de16..183742bf 100644 --- a/server/writer_callback.go +++ b/server/writer_callback.go @@ -18,6 +18,7 @@ package server import ( "context" + "errors" "github.com/milvus-io/milvus/pkg/log" "go.uber.org/zap" @@ -45,9 +46,9 @@ func NewWriteCallback(factory api.MetaStoreFactory, rootPath string, taskID stri } } -func (w *WriteCallback) UpdateTaskCollectionPosition(collectionID int64, collectionName string, pChannelName string, position, opPosition, targetPosition *meta.PositionInfo) { +func (w *WriteCallback) UpdateTaskCollectionPosition(collectionID int64, collectionName string, pChannelName string, position, opPosition, targetPosition *meta.PositionInfo) error { if position == nil { - return + return errors.New("position is nil") } err := store.UpdateTaskCollectionPosition( w.metaStoreFactory.GetTaskCollectionPositionMetaStore(context.Background()), @@ -64,4 +65,5 @@ func (w *WriteCallback) UpdateTaskCollectionPosition(collectionID int64, collect zap.Error(err)) metrics.WriterFailCountVec.WithLabelValues(w.taskID, metrics.WriteFailOnUpdatePosition).Inc() } + return err } diff --git a/server/writer_callback_test.go b/server/writer_callback_test.go index 737e4edf..746a83c0 100644 --- a/server/writer_callback_test.go +++ b/server/writer_callback_test.go @@ -7,6 +7,7 @@ import ( "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/pkg/log" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "go.uber.org/zap" @@ -24,27 +25,34 @@ func TestWriterCallback(t *testing.T) { store := mocks.NewMetaStore[*meta.TaskCollectionPosition](t) callback := NewWriteCallback(factory, "test", "12345") + t.Run("empty position", func(t *testing.T) { + err := callback.UpdateTaskCollectionPosition(1, "test", "test", nil, nil, nil) + assert.Error(t, err) + }) + t.Run("fail", func(t *testing.T) { factory.EXPECT().GetTaskCollectionPositionMetaStore(mock.Anything).Return(store).Once() store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("test")).Once() - callback.UpdateTaskCollectionPosition(1, "test", "test", &meta.PositionInfo{ + err := callback.UpdateTaskCollectionPosition(1, "test", "test", &meta.PositionInfo{ Time: 1, DataPair: &commonpb.KeyDataPair{ Key: "test", Data: []byte("test"), }, }, nil, nil) + assert.Error(t, err) }) t.Run("success", func(t *testing.T) { factory.EXPECT().GetTaskCollectionPositionMetaStore(mock.Anything).Return(store).Once() store.EXPECT().Get(mock.Anything, mock.Anything, mock.Anything).Return([]*meta.TaskCollectionPosition{}, nil).Once() store.EXPECT().Put(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - callback.UpdateTaskCollectionPosition(1, "test", "test", &meta.PositionInfo{ + err := callback.UpdateTaskCollectionPosition(1, "test", "test", &meta.PositionInfo{ Time: 1, DataPair: &commonpb.KeyDataPair{ Key: "test", Data: []byte("test"), }, }, nil, nil) + assert.NoError(t, err) }) }