diff --git a/protocol/streaming/grpc/grpc_streaming_manager.go b/protocol/streaming/grpc/grpc_streaming_manager.go index e5291cc1f9c..fc2f4e31ef4 100644 --- a/protocol/streaming/grpc/grpc_streaming_manager.go +++ b/protocol/streaming/grpc/grpc_streaming_manager.go @@ -81,7 +81,9 @@ func NewGrpcStreamingManager( for { select { case <-grpcStreamingManager.ticker.C: + grpcStreamingManager.Lock() grpcStreamingManager.FlushStreamUpdates() + grpcStreamingManager.Unlock() case <-grpcStreamingManager.done: grpcStreamingManager.logger.Info( "GRPC Stream poller goroutine shutting down", @@ -205,9 +207,9 @@ func (sm *GrpcStreamingManagerImpl) Stop() { sm.done <- true } -// SendSnapshot groups updates by their clob pair ids and -// sends messages to the subscribers. It groups out updates differently -// and bypasses the buffer. +// SendSnapshot sends messages to a particular subscriber without buffering. +// Note this method requires the lock and assumes that the lock has already been +// acquired by the caller. func (sm *GrpcStreamingManagerImpl) SendSnapshot( offchainUpdates *clobtypes.OffchainUpdates, subscriptionId uint32, @@ -388,6 +390,8 @@ func (sm *GrpcStreamingManagerImpl) AddUpdatesToCache( } // FlushStreamUpdates takes in a map of clob pair id to stream updates and emits them to subscribers. +// Note this method requires the lock and assumes that the lock has already been +// acquired by the caller. func (sm *GrpcStreamingManagerImpl) FlushStreamUpdates() { defer metrics.ModuleMeasureSince( metrics.FullNodeGrpc, @@ -395,9 +399,6 @@ func (sm *GrpcStreamingManagerImpl) FlushStreamUpdates() { time.Now(), ) - sm.Lock() - defer sm.Unlock() - // Non-blocking send updates through subscriber's buffered channel. // If the buffer is full, drop the subscription. idsToRemove := make([]uint32, 0) @@ -446,6 +447,10 @@ func (sm *GrpcStreamingManagerImpl) InitializeNewGrpcStreams( sm.Lock() defer sm.Unlock() + // Flush any pending updates before sending the snapshot to avoid + // race conditions with the snapshot. + sm.FlushStreamUpdates() + updatesByClobPairId := make(map[uint32]*clobtypes.OffchainUpdates) for subscriptionId, subscription := range sm.orderbookSubscriptions { subscription.initialize.Do(