From 61399e53b92a0a63d402fa802d9fd0c8fcb90715 Mon Sep 17 00:00:00 2001 From: Yu Xia Date: Mon, 24 May 2021 12:57:30 -0700 Subject: [PATCH] Fix update domain replication ack level (#4212) * Fix update domain replication ack level --- common/domain/replication_queue.go | 50 +++++++------------ .../persistence/cassandra/cassandraQueue.go | 2 +- .../persistence-tests/queuePersistenceTest.go | 3 ++ common/persistence/sql/sqlQueue.go | 5 +- service/frontend/adminHandler.go | 12 ++--- .../domain_replication_processor.go | 4 ++ .../domain_replication_processor_test.go | 3 ++ service/worker/replicator/replicator.go | 1 + 8 files changed, 35 insertions(+), 45 deletions(-) diff --git a/common/domain/replication_queue.go b/common/domain/replication_queue.go index c3440fba01f..b153e4bf2ab 100644 --- a/common/domain/replication_queue.go +++ b/common/domain/replication_queue.go @@ -57,28 +57,25 @@ func NewReplicationQueue( logger log.Logger, ) ReplicationQueue { return &replicationQueueImpl{ - queue: queue, - clusterName: clusterName, - metricsClient: metricsClient, - logger: logger, - encoder: codec.NewThriftRWEncoder(), - ackNotificationChan: make(chan bool), - done: make(chan bool), - status: common.DaemonStatusInitialized, + queue: queue, + clusterName: clusterName, + metricsClient: metricsClient, + logger: logger, + encoder: codec.NewThriftRWEncoder(), + done: make(chan bool), + status: common.DaemonStatusInitialized, } } type ( replicationQueueImpl struct { - queue persistence.QueueManager - clusterName string - metricsClient metrics.Client - logger log.Logger - encoder codec.BinaryEncoder - ackLevelUpdated bool - ackNotificationChan chan bool - done chan bool - status int32 + queue persistence.QueueManager + clusterName string + metricsClient metrics.Client + logger log.Logger + encoder codec.BinaryEncoder + done chan bool + status int32 } // ReplicationQueue is used to publish and list domain replication tasks @@ -177,16 +174,10 @@ func (q *replicationQueueImpl) UpdateAckLevel( clusterName string, ) error { - err := q.queue.UpdateAckLevel(ctx, lastProcessedMessageID, clusterName) - if err != nil { + if err := q.queue.UpdateAckLevel(ctx, lastProcessedMessageID, clusterName); err != nil { return fmt.Errorf("failed to update ack level: %v", err) } - select { - case q.ackNotificationChan <- true: - default: - } - return nil } @@ -305,16 +296,9 @@ func (q *replicationQueueImpl) purgeProcessor() { case <-q.done: return case <-ticker.C: - if q.ackLevelUpdated { - err := q.purgeAckedMessages() - if err != nil { - q.logger.Warn("Failed to purge acked domain replication messages.", tag.Error(err)) - } else { - q.ackLevelUpdated = false - } + if err := q.purgeAckedMessages(); err != nil { + q.logger.Warn("Failed to purge acked domain replication messages.", tag.Error(err)) } - case <-q.ackNotificationChan: - q.ackLevelUpdated = true } } } diff --git a/common/persistence/cassandra/cassandraQueue.go b/common/persistence/cassandra/cassandraQueue.go index 2897d4445ca..aefed53e5e7 100644 --- a/common/persistence/cassandra/cassandraQueue.go +++ b/common/persistence/cassandra/cassandraQueue.go @@ -364,7 +364,7 @@ func (q *nosqlQueue) updateAckLevel( } // Ignore possibly delayed message - if queueMetadata.ClusterAckLevels[clusterName] > messageID { + if ackLevel, ok := queueMetadata.ClusterAckLevels[clusterName]; ok && ackLevel >= messageID { return nil } diff --git a/common/persistence/persistence-tests/queuePersistenceTest.go b/common/persistence/persistence-tests/queuePersistenceTest.go index ff18f5e9437..bdbc3b946b9 100644 --- a/common/persistence/persistence-tests/queuePersistenceTest.go +++ b/common/persistence/persistence-tests/queuePersistenceTest.go @@ -122,6 +122,9 @@ func (s *QueuePersistenceSuite) TestQueueMetadataOperations() { err = s.UpdateAckLevel(ctx, 25, "test2") s.Require().NoError(err) + err = s.UpdateAckLevel(ctx, 24, "test2") + s.Require().NoError(err) + clusterAckLevels, err = s.GetAckLevels(ctx) s.Require().NoError(err) s.Assert().Len(clusterAckLevels, 2) diff --git a/common/persistence/sql/sqlQueue.go b/common/persistence/sql/sqlQueue.go index 2a5341ddc18..9b12a46bf44 100644 --- a/common/persistence/sql/sqlQueue.go +++ b/common/persistence/sql/sqlQueue.go @@ -22,9 +22,8 @@ package sql import ( "context" - "fmt" - "database/sql" + "fmt" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/persistence" @@ -133,7 +132,7 @@ func (q *sqlQueue) UpdateAckLevel( } // Ignore possibly delayed message - if clusterAckLevels[clusterName] > messageID { + if ackLevel, ok := clusterAckLevels[clusterName]; ok && ackLevel >= messageID { return nil } diff --git a/service/frontend/adminHandler.go b/service/frontend/adminHandler.go index aaf441f415e..3c1bb70eae3 100644 --- a/service/frontend/adminHandler.go +++ b/service/frontend/adminHandler.go @@ -657,14 +657,10 @@ func (adh *adminHandlerImpl) GetDomainReplicationMessages( if request.LastProcessedMessageID != nil { lastProcessedMessageID = request.GetLastProcessedMessageID() } - - if lastProcessedMessageID != defaultLastMessageID { - err := adh.GetDomainReplicationQueue().UpdateAckLevel(ctx, lastProcessedMessageID, request.GetClusterName()) - if err != nil { - adh.GetLogger().Warn("Failed to update domain replication queue ack level.", - tag.TaskID(int64(lastProcessedMessageID)), - tag.ClusterName(request.GetClusterName())) - } + if err := adh.GetDomainReplicationQueue().UpdateAckLevel(ctx, lastProcessedMessageID, request.GetClusterName()); err != nil { + adh.GetLogger().Warn("Failed to update domain replication queue ack level.", + tag.TaskID(int64(lastProcessedMessageID)), + tag.ClusterName(request.GetClusterName())) } return &types.GetDomainReplicationMessagesResponse{ diff --git a/service/worker/replicator/domain_replication_processor.go b/service/worker/replicator/domain_replication_processor.go index 38689e1742c..e238c9f7c7a 100644 --- a/service/worker/replicator/domain_replication_processor.go +++ b/service/worker/replicator/domain_replication_processor.go @@ -49,6 +49,7 @@ const ( func newDomainReplicationProcessor( sourceCluster string, + currentCluster string, logger log.Logger, remotePeer admin.Client, metricsClient metrics.Client, @@ -67,6 +68,7 @@ func newDomainReplicationProcessor( serviceResolver: serviceResolver, status: common.DaemonStatusInitialized, sourceCluster: sourceCluster, + currentCluster: currentCluster, logger: logger, remotePeer: remotePeer, taskExecutor: taskExecutor, @@ -85,6 +87,7 @@ type ( serviceResolver membership.ServiceResolver status int32 sourceCluster string + currentCluster string logger log.Logger remotePeer admin.Client taskExecutor domain.ReplicationTaskExecutor @@ -141,6 +144,7 @@ func (p *domainReplicationProcessor) fetchDomainReplicationTasks() { request := &types.GetDomainReplicationMessagesRequest{ LastRetrievedMessageID: common.Int64Ptr(p.lastRetrievedMessageID), LastProcessedMessageID: common.Int64Ptr(p.lastProcessedMessageID), + ClusterName: p.currentCluster, } response, err := p.remotePeer.GetDomainReplicationMessages(ctx, request) defer cancel() diff --git a/service/worker/replicator/domain_replication_processor_test.go b/service/worker/replicator/domain_replication_processor_test.go index 2fdea5233da..b77e8ae8ca3 100644 --- a/service/worker/replicator/domain_replication_processor_test.go +++ b/service/worker/replicator/domain_replication_processor_test.go @@ -45,6 +45,7 @@ type domainReplicationSuite struct { controller *gomock.Controller sourceCluster string + currentCluster string taskExecutor *domain.MockReplicationTaskExecutor remoteClient *admin.MockClient domainReplicationQueue *domain.MockReplicationQueue @@ -62,6 +63,7 @@ func (s *domainReplicationSuite) SetupTest() { resource := resource.NewTest(s.controller, metrics.Worker) s.sourceCluster = "active" + s.currentCluster = "standby" s.taskExecutor = domain.NewMockReplicationTaskExecutor(s.controller) s.domainReplicationQueue = domain.NewMockReplicationQueue(s.controller) s.remoteClient = resource.RemoteAdminClient @@ -69,6 +71,7 @@ func (s *domainReplicationSuite) SetupTest() { serviceResolver.EXPECT().Lookup(s.sourceCluster).Return(resource.GetHostInfo(), nil).AnyTimes() s.replicationProcessor = newDomainReplicationProcessor( s.sourceCluster, + s.currentCluster, resource.GetLogger(), s.remoteClient, resource.GetMetricsClient(), diff --git a/service/worker/replicator/replicator.go b/service/worker/replicator/replicator.go index 54c3bb48f35..f95faec6de9 100644 --- a/service/worker/replicator/replicator.go +++ b/service/worker/replicator/replicator.go @@ -88,6 +88,7 @@ func (r *Replicator) Start() error { if clusterName != currentClusterName { processor := newDomainReplicationProcessor( clusterName, + currentClusterName, r.logger.WithTags(tag.ComponentReplicationTaskProcessor, tag.SourceCluster(clusterName)), r.clientBean.GetRemoteAdminClient(clusterName), r.metricsClient,