remove the task info param when starting to replicate events #155

Merged: 1 commit, Nov 13, 2024
1 change: 1 addition & 0 deletions core/api/replicate_manager.go
@@ -58,6 +58,7 @@ type ReplicateAPIEvent struct {
PartitionInfo *pb.PartitionInfo
ReplicateInfo *commonpb.ReplicateInfo
ReplicateParam ReplicateParam
TaskID string
Error error
}

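The new `TaskID` field means each API event carries the id of the replication task that produced it, instead of that id being resolved later from a separate task info param. A minimal sketch of an emitter using the field is below; the import paths and the surrounding event channel are assumptions for illustration, not taken from this diff.

```go
package example

import (
	"context"

	"github.com/zilliztech/milvus-cdc/core/api"  // import paths assumed from the repo layout
	"github.com/zilliztech/milvus-cdc/core/util"
)

// emitEvent is a sketch: it resolves the task ID from the context and stamps it
// on the event, the pattern this PR applies in sendCreateCollectionEvent and
// StartReadCollection. Other event fields are omitted here.
func emitEvent(ctx context.Context, events chan<- *api.ReplicateAPIEvent) {
	events <- &api.ReplicateAPIEvent{
		EventType: api.ReplicateCreateCollection,
		TaskID:    util.GetTaskIDFromCtx(ctx), // new field added by this PR
	}
}
```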
4 changes: 3 additions & 1 deletion core/api/replicate_msg.go
@@ -27,14 +27,16 @@ type ReplicateMsg struct {
CollectionName string
CollectionID int64
PChannelName string
TaskID string
MsgPack *msgstream.MsgPack
}

func GetReplicateMsg(pchannelName string, collectionName string, collectionID int64, msgPack *msgstream.MsgPack) *ReplicateMsg {
func GetReplicateMsg(pchannelName string, collectionName string, collectionID int64, msgPack *msgstream.MsgPack, taskID string) *ReplicateMsg {
return &ReplicateMsg{
CollectionName: collectionName,
CollectionID: collectionID,
PChannelName: pchannelName,
TaskID: taskID,
MsgPack: msgPack,
}
}
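Callers of `GetReplicateMsg` now pass the task id as an explicit trailing argument, and it ends up in `ReplicateMsg.TaskID`. An illustrative call site is sketched below; the channel name, collection name, and IDs are made up, and the msgstream import path is assumed.

```go
package example

import (
	"github.com/milvus-io/milvus/pkg/mq/msgstream" // import path assumed

	"github.com/zilliztech/milvus-cdc/core/api"
)

// buildMsg is an illustrative call site for the new signature; only the
// trailing taskID argument is the part introduced by this PR.
func buildMsg(pack *msgstream.MsgPack, taskID string) *api.ReplicateMsg {
	return api.GetReplicateMsg("by-dev-rootcoord-dml_0", "demo_collection", 1001, pack, taskID)
}
```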
3 changes: 2 additions & 1 deletion core/reader/collection_reader.go
@@ -58,7 +58,7 @@ var _ api.Reader = (*CollectionReader)(nil)
type CollectionReader struct {
api.DefaultReader

id string
id string // which is task id
channelManager api.ChannelManager
metaOp api.MetaOp
channelSeekPositions map[int64]map[string]*msgpb.MsgPosition
@@ -92,6 +92,7 @@ func NewCollectionReader(id string,

func (reader *CollectionReader) StartRead(ctx context.Context) {
reader.startOnce.Do(func() {
ctx = util.GetCtxWithTaskID(ctx, reader.id)
reader.metaOp.SubscribeCollectionEvent(reader.id, func(info *pb.CollectionInfo) bool {
collectionLog := log.With(
zap.String("task_id", reader.id),
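`GetCtxWithTaskID` and `GetTaskIDFromCtx` are not part of this diff; the reader only stores its id (which doubles as the task id) in the context so the channel manager can read it back. A plausible shape for those helpers, purely as an assumption about what core/util provides:

```go
package util

import "context"

// Assumed implementation sketch; the real helpers live in core/util and may differ.
type taskIDKey struct{}

// GetCtxWithTaskID returns a child context that carries the task ID.
func GetCtxWithTaskID(ctx context.Context, taskID string) context.Context {
	return context.WithValue(ctx, taskIDKey{}, taskID)
}

// GetTaskIDFromCtx reads the task ID back, returning "" when it was never set.
func GetTaskIDFromCtx(ctx context.Context) string {
	if v, ok := ctx.Value(taskIDKey{}).(string); ok {
		return v
	}
	return ""
}
```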
52 changes: 33 additions & 19 deletions core/reader/replicate_channel_manager.go
@@ -99,7 +99,7 @@ type replicateChannelManager struct {
downstream string
}

func NewReplicateChannelManagerWithDispatchClient(
func NewReplicateChannelManager(
dispatchClient msgdispatcher.Client,
factory msgstream.Factory,
client api.TargetAPI,
@@ -281,6 +281,7 @@ func (r *replicateChannelManager) sendCreateCollectionEvent(ctx context.Context,
// TODO fubang should give an error when the collection shard num is more than the target dml channel num
r.apiEventChan <- &api.ReplicateAPIEvent{
EventType: api.ReplicateCreateCollection,
TaskID: util.GetTaskIDFromCtx(ctx),
CollectionInfo: info,
ReplicateInfo: &commonpb.ReplicateInfo{
IsReplicate: true,
@@ -332,6 +333,7 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m
}
return nil
}
taskID := util.GetTaskIDFromCtx(ctx)
r.collectionLock.Lock()
if _, ok := r.replicateCollections[info.ID]; ok {
r.collectionLock.Unlock()
@@ -351,6 +353,7 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m
MsgTimestamp: msgTs,
},
ReplicateParam: api.ReplicateParam{Database: targetInfo.DatabaseName},
TaskID: taskID,
}:
r.droppedCollections.Store(info.ID, struct{}{})
for _, name := range info.PhysicalChannelNames {
@@ -369,7 +372,7 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m
err = ForeachChannel(info.VirtualChannelNames, targetInfo.VChannels, func(sourceVChannel, targetVChannel string) error {
sourcePChannel := funcutil.ToPhysicalChannel(sourceVChannel)
targetPChannel := funcutil.ToPhysicalChannel(targetVChannel)
channelHandler, err := r.startReadChannel(&model.SourceCollectionInfo{
channelHandler, err := r.startReadChannel(ctx, &model.SourceCollectionInfo{
PChannel: sourcePChannel,
VChannel: sourceVChannel,
CollectionID: info.ID,
@@ -445,8 +448,11 @@ func ForeachChannel(sourcePChannels, targetPChannels []string, f func(sourcePCha
func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *model.DatabaseInfo, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo) error {
var handlers []*replicateChannelHandler
collectionID := collectionInfo.ID
taskID := util.GetTaskIDFromCtx(ctx)
partitionLog := log.With(zap.Int64("partition_id", partitionInfo.PartitionID), zap.Int64("collection_id", collectionID),
zap.String("collection_name", collectionInfo.Schema.Name), zap.String("partition_name", partitionInfo.PartitionName))
zap.String("collection_name", collectionInfo.Schema.Name), zap.String("partition_name", partitionInfo.PartitionName),
zap.String("task_id", taskID),
)
if dbInfo.Dropped {
partitionLog.Info("the database has been dropped when add partition")
return nil
@@ -515,6 +521,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode
MsgTimestamp: partitionInfo.PartitionCreatedTimestamp,
},
ReplicateParam: api.ReplicateParam{Database: dbInfo.Name},
TaskID: taskID,
}:
case <-ctx.Done():
partitionLog.Warn("context is done when adding partition")
@@ -535,6 +542,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode
MsgTimestamp: msgTs,
},
ReplicateParam: api.ReplicateParam{Database: dbInfo.Name},
TaskID: taskID,
}:
r.droppedPartitions.Store(partitionInfo.PartitionID, struct{}{})
for _, handler := range handlers {
@@ -554,7 +562,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode
r.replicatePartitions[collectionID][partitionInfo.PartitionID] = barrier.CloseChan
r.partitionLock.Unlock()
for _, handler := range handlers {
err = handler.AddPartitionInfo(collectionInfo, partitionInfo, barrier.BarrierChan)
err = handler.AddPartitionInfo(taskID, collectionInfo, partitionInfo, barrier.BarrierChan)
if err != nil {
return err
}
@@ -619,7 +627,7 @@ func (r *replicateChannelManager) GetChannelLatestMsgID(ctx context.Context, cha
// startReadChannel start read channel
// sourcePChannel: source milvus channel name, collectionID: source milvus collection id, startPosition: start position of the source milvus collection
// targetInfo: target collection info, it will be used to replace the message info in the source milvus channel
func (r *replicateChannelManager) startReadChannel(sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo) (*replicateChannelHandler, error) {
func (r *replicateChannelManager) startReadChannel(ctx context.Context, sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo) (*replicateChannelHandler, error) {
r.channelLock.Lock()
defer r.channelLock.Unlock()

@@ -631,16 +639,20 @@ func (r *replicateChannelManager) startReadChannel(sourceInfo *model.SourceColle
)
channelMappingKey := r.channelMapping.GetMapKey(sourceInfo.PChannel, targetInfo.PChannel)
channelMappingValue := r.channelMapping.GetMapValue(sourceInfo.PChannel, targetInfo.PChannel)
taskID := util.GetTaskIDFromCtx(ctx)

// TODO how to handle the seek position when the pchannel has been replicated
channelHandler, ok := r.channelHandlerMap[channelMappingKey]
if !ok {
var err error
channelHandler, err = initReplicateChannelHandler(r.getCtx(), sourceInfo, targetInfo, r.targetClient, r.metaOp, r.apiEventChan, &model.HandlerOpts{
MessageBufferSize: r.messageBufferSize,
TTInterval: r.ttInterval,
RetryOptions: r.retryOptions,
}, r.streamCreator, r.downstream, channelMappingKey == sourceInfo.PChannel)
}, r.streamCreator,
r.downstream,
channelMappingKey == sourceInfo.PChannel,
taskID,
)
if err != nil {
channelLog.Warn("init replicate channel handler failed", zap.Error(err))
return nil, err
@@ -678,7 +690,7 @@ func (r *replicateChannelManager) startReadChannel(sourceInfo *model.SourceColle
}
// the msg dispatch client maybe blocked, and has get the target channel,
// so we can use the goroutine and release the channelLock
go channelHandler.AddCollection(sourceInfo, targetInfo)
go channelHandler.AddCollection(taskID, sourceInfo, targetInfo)
return nil, nil
}

@@ -879,7 +891,7 @@ type replicateChannelHandler struct {
startReadChan chan struct{}
}

func (r *replicateChannelHandler) AddCollection(sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo) {
func (r *replicateChannelHandler) AddCollection(taskID string, sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo) {
select {
case <-r.replicateCtx.Done():
log.Warn("replicate channel handler closed")
@@ -916,7 +928,7 @@ func (r *replicateChannelHandler) AddCollection(sourceInfo *model.SourceCollecti
return
}

r.innerHandleReplicateMsg(false, api.GetReplicateMsg(sourceInfo.PChannel, targetInfo.CollectionName, collectionID, msgPack))
r.innerHandleReplicateMsg(false, api.GetReplicateMsg(sourceInfo.PChannel, targetInfo.CollectionName, collectionID, msgPack, taskID))
}
}
}()
@@ -964,7 +976,7 @@ func (r *replicateChannelHandler) AddCollection(sourceInfo *model.SourceCollecti
},
},
}
r.generatePackChan <- api.GetReplicateMsg(sourceInfo.PChannel, targetInfo.CollectionName, collectionID, generateMsgPack)
r.generatePackChan <- api.GetReplicateMsg(sourceInfo.PChannel, targetInfo.CollectionName, collectionID, generateMsgPack, taskID)
dropCollectionLog.Info("has generate msg for dropped collection")
return struct{}{}, nil
})
@@ -996,7 +1008,7 @@ func (r *replicateChannelHandler) RemoveCollection(collectionID int64) {
log.Info("remove collection from handler", zap.Int64("collection_id", collectionID))
}

func (r *replicateChannelHandler) AddPartitionInfo(collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo, barrierChan chan<- uint64) error {
func (r *replicateChannelHandler) AddPartitionInfo(taskID string, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo, barrierChan chan<- uint64) error {
collectionID := collectionInfo.ID
partitionID := partitionInfo.PartitionID
collectionName := collectionInfo.Schema.Name
@@ -1071,7 +1083,7 @@ func (r *replicateChannelHandler) AddPartitionInfo(collectionInfo *pb.Collection
},
},
}
r.generatePackChan <- api.GetReplicateMsg(sourcePChannel, collectionName, collectionID, generateMsgPack)
r.generatePackChan <- api.GetReplicateMsg(sourcePChannel, collectionName, collectionID, generateMsgPack, taskID)
partitionLog.Info("has generate msg for dropped partition")
return struct{}{}, nil
})
@@ -1178,13 +1190,14 @@ func (r *replicateChannelHandler) getTSManagerChannelKey(channelName string) str

func (r *replicateChannelHandler) innerHandleReplicateMsg(forward bool, msg *api.ReplicateMsg) {
msgPack := msg.MsgPack
p := r.handlePack(forward, msgPack)
p := r.handlePack(forward, msgPack, msg.TaskID)
if p == api.EmptyMsgPack {
return
}
p.CollectionID = msg.CollectionID
p.CollectionName = msg.CollectionName
p.PChannelName = msg.PChannelName
p.TaskID = msg.TaskID
GetTSManager().SendTargetMsg(r.getTSManagerChannelKey(r.targetPChannel), p)
}

@@ -1323,7 +1336,7 @@ func isSupportedMsgType(msgType commonpb.MsgType) bool {
msgType == commonpb.MsgType_DropPartition
}

func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPack) *api.ReplicateMsg {
func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPack, taskID string) *api.ReplicateMsg {
sort.Slice(pack.Msgs, func(i, j int) bool {
return pack.Msgs[i].BeginTs() < pack.Msgs[j].BeginTs() ||
(pack.Msgs[i].BeginTs() == pack.Msgs[j].BeginTs() && pack.Msgs[i].Type() == commonpb.MsgType_Delete)
@@ -1591,7 +1604,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
}

if forwardChannel != "" {
r.forwardMsgFunc(forwardChannel, api.GetReplicateMsg(streamPChannel, sourceCollectionName, sourceCollectionID, newPack))
r.forwardMsgFunc(forwardChannel, api.GetReplicateMsg(streamPChannel, sourceCollectionName, sourceCollectionID, newPack, taskID))
return api.EmptyMsgPack
}

@@ -1630,7 +1643,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
resetLastTs := needTsMsg
needTsMsg = needTsMsg || len(newPack.Msgs) == 0
if !needTsMsg {
return api.GetReplicateMsg("", sourceCollectionName, sourceCollectionID, newPack)
return api.GetReplicateMsg("", sourceCollectionName, sourceCollectionID, newPack, "")
}
timeTickResult := &msgpb.TimeTickMsg{
Base: commonpbutil.NewMsgBase(
@@ -1655,7 +1668,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa
msgTime, _ := tsoutil.ParseHybridTs(generateTS)
TSMetricVec.WithLabelValues(r.targetPChannel).Set(float64(msgTime))
r.ttRateLog.Debug("time tick msg", zap.String("channel", r.targetPChannel), zap.Uint64("max_ts", generateTS))
return api.GetReplicateMsg("", sourceCollectionName, sourceCollectionID, newPack)
return api.GetReplicateMsg("", sourceCollectionName, sourceCollectionID, newPack, "")
}

func resetMsgPackTimestamp(pack *msgstream.MsgPack, newTimestamp uint64) bool {
@@ -1783,6 +1796,7 @@ func initReplicateChannelHandler(ctx context.Context,
opts *model.HandlerOpts,
streamCreator StreamCreator,
downstream string, sourceKey bool,
taskID string,
) (*replicateChannelHandler, error) {
err := streamCreator.CheckConnection(ctx, sourceInfo.VChannel, sourceInfo.SeekPosition)
if err != nil {
@@ -1816,6 +1830,6 @@ func initReplicateChannelHandler(ctx context.Context,
sourceKey: sourceKey,
startReadChan: make(chan struct{}),
}
go channelHandler.AddCollection(sourceInfo, targetInfo)
go channelHandler.AddCollection(taskID, sourceInfo, targetInfo)
return channelHandler, nil
}
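Taken together, the changes in this file thread the task id from the reader's context down to everything the handler emits: StartReadCollection and startReadChannel resolve it once via util.GetTaskIDFromCtx, hand it to initReplicateChannelHandler, AddCollection, and AddPartitionInfo, and from there it is stamped onto every ReplicateMsg and API event. A condensed, hedged sketch of that flow, with placeholder types standing in for the real manager and handler types:

```go
package example

import "context"

type ctxKey struct{}

// taskIDFromCtx stands in for util.GetTaskIDFromCtx.
func taskIDFromCtx(ctx context.Context) string {
	v, _ := ctx.Value(ctxKey{}).(string)
	return v
}

// replicateMsg stands in for api.ReplicateMsg: the task ID now travels with the pack.
type replicateMsg struct {
	PChannel string
	TaskID   string
}

type handler struct{ out chan *replicateMsg }

// startReadChannel mirrors the new shape: ctx is passed in, the task ID is
// resolved once, and it is handed to the handler as an explicit argument.
func startReadChannel(ctx context.Context, pchannel string) *handler {
	taskID := taskIDFromCtx(ctx)
	h := &handler{out: make(chan *replicateMsg, 1)}
	go h.addCollection(taskID, pchannel)
	return h
}

// addCollection stamps the task ID on every message it produces, as
// GetReplicateMsg now does with its trailing taskID argument.
func (h *handler) addCollection(taskID, pchannel string) {
	h.out <- &replicateMsg{PChannel: pchannel, TaskID: taskID}
}
```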
20 changes: 10 additions & 10 deletions core/reader/replicate_channel_manager_test.go
@@ -272,7 +272,7 @@ func TestStartReadCollectionForMilvus(t *testing.T) {
t.Run("read channel", func(t *testing.T) {
{
// start read
handler, err := realManager.startReadChannel(&model.SourceCollectionInfo{
handler, err := realManager.startReadChannel(context.Background(), &model.SourceCollectionInfo{
PChannel: "test_read_channel",
VChannel: "test_read_channel_v0",
CollectionID: 11001,
@@ -294,7 +294,7 @@ func TestStartReadCollectionForMilvus(t *testing.T) {
handler.startReadChannel()
assert.Equal(t, "ttest_read_channel", <-realManager.GetChannelChan())

_, err = realManager.startReadChannel(&model.SourceCollectionInfo{
_, err = realManager.startReadChannel(context.Background(), &model.SourceCollectionInfo{
PChannel: "test_read_channel_2",
VChannel: "test_read_channel_2_v0",
CollectionID: 11002,
@@ -436,7 +436,7 @@ func TestStartReadCollectionForKafka(t *testing.T) {
t.Run("read channel", func(t *testing.T) {
{
// start read
handler, err := realManager.startReadChannel(&model.SourceCollectionInfo{
handler, err := realManager.startReadChannel(context.Background(), &model.SourceCollectionInfo{
PChannel: "kafka_test_read_channel",
VChannel: "kafka_test_read_channel_v0",
CollectionID: 11001,
@@ -458,7 +458,7 @@ func TestStartReadCollectionForKafka(t *testing.T) {
handler.startReadChannel()
assert.Equal(t, "kafka_ttest_read_channel", <-realManager.GetChannelChan())

_, err = realManager.startReadChannel(&model.SourceCollectionInfo{
_, err = realManager.startReadChannel(context.Background(), &model.SourceCollectionInfo{
PChannel: "kafka_test_read_channel_2",
VChannel: "kafka_test_read_channel_2_v0",
CollectionID: 11002,
@@ -565,7 +565,7 @@ func newReplicateChannelHandler(ctx context.Context,
factory: opts.Factory,
}

channelHandler, err := initReplicateChannelHandler(ctx, sourceInfo, targetInfo, targetClient, metaOp, apiEventChan, opts, creator, "milvus", true)
channelHandler, err := initReplicateChannelHandler(ctx, sourceInfo, targetInfo, targetClient, metaOp, apiEventChan, opts, creator, "milvus", true, "")
if err == nil {
channelHandler.addCollectionCnt = new(int)
channelHandler.addCollectionLock = &deadlock.RWMutex{}
@@ -654,7 +654,7 @@ func TestReplicateChannelHandler(t *testing.T) {
assert.True(t, handler.containCollection("foo"))
handler.Close()

handler.AddCollection(&model.SourceCollectionInfo{
handler.AddCollection("", &model.SourceCollectionInfo{
CollectionID: 1,
}, &model.TargetCollectionInfo{
CollectionName: "test",
@@ -698,7 +698,7 @@ func TestReplicateChannelHandler(t *testing.T) {

go func() {
time.Sleep(600 * time.Millisecond)
handler.AddCollection(&model.SourceCollectionInfo{
handler.AddCollection("", &model.SourceCollectionInfo{
CollectionID: 2,
}, &model.TargetCollectionInfo{
CollectionName: "test2",
@@ -708,7 +708,7 @@ func TestReplicateChannelHandler(t *testing.T) {
DroppedPartition: make(map[int64]struct{}),
})
}()
err = handler.AddPartitionInfo(&pb.CollectionInfo{
err = handler.AddPartitionInfo("", &pb.CollectionInfo{
ID: 2,
Schema: &schemapb.CollectionSchema{
Name: "test2",
@@ -740,7 +740,7 @@ func TestReplicateChannelHandler(t *testing.T) {
}, nil).Once()
assert.EqualValues(t, 0, handler.updateTargetPartitionInfo(3, "col3", "p2"))
assert.EqualValues(t, 0, handler.updateTargetPartitionInfo(3, "col3", "p2"))
handler.AddCollection(&model.SourceCollectionInfo{
handler.AddCollection("", &model.SourceCollectionInfo{
CollectionID: 3,
}, &model.TargetCollectionInfo{
CollectionName: "col3",
@@ -798,7 +798,7 @@ func TestReplicateChannelHandler(t *testing.T) {
}
GetTSManager().InitTSInfo(replicateID, handler.targetPChannel, 100*time.Millisecond, math.MaxUint64, 10)

err = handler.AddPartitionInfo(&pb.CollectionInfo{
err = handler.AddPartitionInfo("", &pb.CollectionInfo{
ID: 1,
Schema: &schemapb.CollectionSchema{
Name: "test",