Skip to content

Commit

Permalink
resolve race
Browse files Browse the repository at this point in the history
Signed-off-by: husharp <ihusharp@gmail.com>
  • Loading branch information
HuSharp committed Jul 8, 2024
1 parent aecb4ab commit dff1346
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 46 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ pd-analysis:
pd-heartbeat-bench:
cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-heartbeat-bench pd-heartbeat-bench/main.go
simulator:
cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-simulator pd-simulator/main.go
cd tools && GOEXPERIMENT=$(BUILD_GOEXPERIMENT) CGO_ENABLED=$(BUILD_CGO_ENABLED) go build $(BUILD_FLAGS) -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/pd-simulator pd-simulator/main.go
regions-dump:
cd tools && CGO_ENABLED=0 go build -gcflags '$(GCFLAGS)' -ldflags '$(LDFLAGS)' -o $(BUILD_BIN_PATH)/regions-dump regions-dump/main.go
stores-dump:
Expand Down
4 changes: 3 additions & 1 deletion pkg/core/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -2211,8 +2211,10 @@ func NewTestRegionInfo(regionID, storeID uint64, start, end []byte, opts ...Regi
}

// TraverseRegions executes a function on all regions.
// ONLY for simulator now and function need to be self-locked.
// ONLY for simulator now and only for READ.
func (r *RegionsInfo) TraverseRegions(lockedFunc func(*RegionInfo)) {
r.t.RLock()
defer r.t.RUnlock()

Check warning on line 2217 in pkg/core/region.go

View check run for this annotation

Codecov / codecov/patch

pkg/core/region.go#L2216-L2217

Added lines #L2216 - L2217 were not covered by tests
for _, item := range r.regions {
lockedFunc(item.RegionInfo)
}
Expand Down
29 changes: 14 additions & 15 deletions tools/pd-simulator/simulator/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -50,7 +51,7 @@ type Client interface {
}

const (
pdTimeout = time.Second
pdTimeout = 3 * time.Second
maxInitClusterRetries = 100
// retry to get leader URL
leaderChangedWaitTime = 100 * time.Millisecond
Expand All @@ -62,13 +63,13 @@ var (
errFailInitClusterID = errors.New("[pd] failed to get cluster id")
PDHTTPClient pdHttp.Client
SD pd.ServiceDiscovery
ClusterID uint64
ClusterID atomic.Uint64
)

// requestHeader returns a header for fixed ClusterID.
func requestHeader() *pdpb.RequestHeader {
return &pdpb.RequestHeader{
ClusterId: ClusterID,
ClusterId: ClusterID.Load(),
}
}

Expand Down Expand Up @@ -205,12 +206,11 @@ func (c *client) reportRegionHeartbeat(ctx context.Context, stream pdpb.PD_Regio
defer wg.Done()
for {
select {
case r := <-c.reportRegionHeartbeatCh:
if r == nil {
case region := <-c.reportRegionHeartbeatCh:
if region == nil {
simutil.Logger.Error("report nil regionHeartbeat error",
zap.String("tag", c.tag), zap.Error(errors.New("nil region")))
}
region := r.Clone()
request := &pdpb.RegionHeartbeatRequest{
Header: requestHeader(),
Region: region.GetMeta(),
Expand Down Expand Up @@ -281,9 +281,8 @@ func (c *client) PutStore(ctx context.Context, store *metapb.Store) error {
return nil
}

func (c *client) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error {
func (c *client) StoreHeartbeat(ctx context.Context, newStats *pdpb.StoreStats) error {
ctx, cancel := context.WithTimeout(ctx, pdTimeout)
newStats := typeutil.DeepClone(stats, core.StoreStatsFactory)
resp, err := c.pdClient().StoreHeartbeat(ctx, &pdpb.StoreHeartbeatRequest{
Header: requestHeader(),
Stats: newStats,
Expand Down Expand Up @@ -382,8 +381,8 @@ func getLeaderURL(ctx context.Context, conn *grpc.ClientConn) (string, *grpc.Cli
if members.GetHeader().GetError() != nil {
return "", nil, errors.New(members.GetHeader().GetError().String())
}
ClusterID = members.GetHeader().GetClusterId()
if ClusterID == 0 {
ClusterID.Store(members.GetHeader().GetClusterId())
if ClusterID.Load() == 0 {
return "", nil, errors.New("cluster id is 0")
}
if members.GetLeader() == nil {
Expand Down Expand Up @@ -413,9 +412,9 @@ func (rc *RetryClient) PutStore(ctx context.Context, store *metapb.Store) error
return err
}

func (rc *RetryClient) StoreHeartbeat(ctx context.Context, stats *pdpb.StoreStats) error {
func (rc *RetryClient) StoreHeartbeat(ctx context.Context, newStats *pdpb.StoreStats) error {
_, err := rc.requestWithRetry(func() (any, error) {
err := rc.client.StoreHeartbeat(ctx, stats)
err := rc.client.StoreHeartbeat(ctx, newStats)
return nil, err
})
return err
Expand Down Expand Up @@ -466,10 +465,10 @@ retry:
break retry
}
}
if ClusterID == 0 {
if ClusterID.Load() == 0 {
return "", nil, errors.WithStack(errFailInitClusterID)
}
simutil.Logger.Info("get cluster id successfully", zap.Uint64("cluster-id", ClusterID))
simutil.Logger.Info("get cluster id successfully", zap.Uint64("cluster-id", ClusterID.Load()))

// Check if the cluster is already bootstrapped.
ctx, cancel := context.WithTimeout(ctx, pdTimeout)
Expand Down Expand Up @@ -543,7 +542,7 @@ func PutPDConfig(config *sc.PDConfig) error {
}

func ChooseToHaltPDSchedule(halt bool) {
HaltSchedule = halt
HaltSchedule.Store(halt)
PDHTTPClient.SetConfig(context.Background(), map[string]any{
"schedule.halt-scheduling": strconv.FormatBool(halt),
})
Expand Down
23 changes: 11 additions & 12 deletions tools/pd-simulator/simulator/drive.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/pingcap/errors"
Expand Down Expand Up @@ -122,7 +123,7 @@ func (d *Driver) allocID() error {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
rootPath := path.Join("/pd", strconv.FormatUint(ClusterID, 10))
rootPath := path.Join("/pd", strconv.FormatUint(ClusterID.Load(), 10))
allocIDPath := path.Join(rootPath, "alloc_id")
_, err = etcdClient.Put(ctx, allocIDPath, string(typeutil.Uint64ToBytes(maxID+1000)))
if err != nil {
Expand Down Expand Up @@ -176,24 +177,20 @@ func (d *Driver) Tick() {
d.wg.Wait()
}

var HaltSchedule = false
var HaltSchedule atomic.Bool

// Check checks if the simulation is completed.
func (d *Driver) Check() bool {
if !HaltSchedule {
if !HaltSchedule.Load() {
return false
}
length := uint64(len(d.conn.Nodes) + 1)
var stats []info.StoreStats
var stores []*metapb.Store
for index, s := range d.conn.Nodes {
if index >= length {
length = index + 1
}
for _, s := range d.conn.Nodes {
s.statsMutex.RLock()
stores = append(stores, s.Store)
}
stats := make([]info.StoreStats, length)
for index, node := range d.conn.Nodes {
stats[index] = *node.stats
stats = append(stats, *s.stats)
s.statsMutex.RUnlock()
}
return d.simCase.Checker(stores, d.raftEngine.regionsInfo, stats)
}
Expand Down Expand Up @@ -252,11 +249,13 @@ func (d *Driver) GetBootstrapInfo(r *RaftEngine) (*metapb.Store, *metapb.Region,

func (d *Driver) updateNodeAvailable() {
for storeID, n := range d.conn.Nodes {
n.statsMutex.Lock()
if n.hasExtraUsedSpace {
n.stats.StoreStats.Available = n.stats.StoreStats.Capacity - uint64(d.raftEngine.regionsInfo.GetStoreRegionSize(storeID)) - uint64(d.simConfig.RaftStore.ExtraUsedSpace)
} else {
n.stats.StoreStats.Available = n.stats.StoreStats.Capacity - uint64(d.raftEngine.regionsInfo.GetStoreRegionSize(storeID))
}
n.statsMutex.Unlock()
}
}

Expand Down
5 changes: 3 additions & 2 deletions tools/pd-simulator/simulator/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ func (e *DownNode) Run(raft *RaftEngine, _ int64) bool {
}
node.Stop()

raft.TraverseRegions(func(region *core.RegionInfo) {
regions := raft.GetRegions()
for _, region := range regions {
storeIDs := region.GetStoreIDs()
if _, ok := storeIDs[node.Id]; ok {
downPeer := &pdpb.PeerStats{
Expand All @@ -250,6 +251,6 @@ func (e *DownNode) Run(raft *RaftEngine, _ int64) bool {
region = region.Clone(core.WithDownPeers(append(region.GetDownPeers(), downPeer)))
raft.SetRegion(region)
}
})
}
return true
}
26 changes: 16 additions & 10 deletions tools/pd-simulator/simulator/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/tikv/pd/pkg/core"
"github.com/tikv/pd/pkg/ratelimit"
"github.com/tikv/pd/pkg/utils/syncutil"
"github.com/tikv/pd/pkg/utils/typeutil"
"github.com/tikv/pd/tools/pd-simulator/simulator/cases"
sc "github.com/tikv/pd/tools/pd-simulator/simulator/config"
"github.com/tikv/pd/tools/pd-simulator/simulator/info"
Expand All @@ -51,7 +52,7 @@ type Node struct {
cancel context.CancelFunc
raftEngine *RaftEngine
limiter *ratelimit.RateLimiter
sizeMutex syncutil.Mutex
statsMutex syncutil.RWMutex
hasExtraUsedSpace bool
snapStats []*pdpb.SnapshotStat
// PD client
Expand Down Expand Up @@ -179,12 +180,15 @@ func (n *Node) storeHeartBeat() {
if n.GetNodeState() != metapb.NodeState_Preparing && n.GetNodeState() != metapb.NodeState_Serving {
return
}
ctx, cancel := context.WithTimeout(n.ctx, pdTimeout)
n.statsMutex.Lock()
stats := make([]*pdpb.SnapshotStat, len(n.snapStats))
copy(stats, n.snapStats)
n.snapStats = n.snapStats[:0]
n.stats.SnapshotStats = stats
err := n.client.StoreHeartbeat(ctx, &n.stats.StoreStats)
newStats := typeutil.DeepClone(&n.stats.StoreStats, core.StoreStatsFactory)
n.statsMutex.Unlock()
ctx, cancel := context.WithTimeout(n.ctx, pdTimeout)
err := n.client.StoreHeartbeat(ctx, newStats)
if err != nil {
simutil.Logger.Info("report store heartbeat error",
zap.Uint64("node-id", n.GetId()),
Expand All @@ -194,8 +198,8 @@ func (n *Node) storeHeartBeat() {
}

func (n *Node) compaction() {
n.sizeMutex.Lock()
defer n.sizeMutex.Unlock()
n.statsMutex.Lock()
defer n.statsMutex.Unlock()
n.stats.Available += n.stats.ToCompactionSize
n.stats.UsedSize -= n.stats.ToCompactionSize
n.stats.ToCompactionSize = 0
Expand All @@ -211,7 +215,7 @@ func (n *Node) regionHeartBeat() {
if region == nil {
simutil.Logger.Fatal("region not found")
}
err := n.client.RegionHeartbeat(ctx, region)
err := n.client.RegionHeartbeat(ctx, region.Clone())
if err != nil {
simutil.Logger.Info("report region heartbeat error",
zap.Uint64("node-id", n.Id),
Expand Down Expand Up @@ -267,19 +271,21 @@ func (n *Node) Stop() {
}

func (n *Node) incUsedSize(size uint64) {
n.sizeMutex.Lock()
defer n.sizeMutex.Unlock()
n.statsMutex.Lock()
defer n.statsMutex.Unlock()
n.stats.Available -= size
n.stats.UsedSize += size
}

func (n *Node) decUsedSize(size uint64) {
n.sizeMutex.Lock()
defer n.sizeMutex.Unlock()
n.statsMutex.Lock()
defer n.statsMutex.Unlock()
n.stats.ToCompactionSize += size
}

func (n *Node) registerSnapStats(generate, send, total uint64) {
n.statsMutex.Lock()
defer n.statsMutex.Unlock()
stat := pdpb.SnapshotStat{
GenerateDurationSec: generate,
SendDurationSec: send,
Expand Down
17 changes: 14 additions & 3 deletions tools/pd-simulator/simulator/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ func NewRaftEngine(conf *cases.Case, conn *Connection, storeConfig *config.SimCo
}

func (r *RaftEngine) stepRegions() {
r.TraverseRegions(func(region *core.RegionInfo) {
regions := r.GetRegions()
for _, region := range regions {
r.stepLeader(region)
r.stepSplit(region)
})
}
}

func (r *RaftEngine) stepLeader(region *core.RegionInfo) {
Expand Down Expand Up @@ -228,7 +229,10 @@ func (r *RaftEngine) electNewLeader(region *core.RegionInfo) *metapb.Peer {
func (r *RaftEngine) GetRegion(regionID uint64) *core.RegionInfo {
r.RLock()
defer r.RUnlock()
return r.regionsInfo.GetRegion(regionID)
if region := r.regionsInfo.GetRegion(regionID); region != nil {
return region.Clone()
}
return nil
}

// GetRegionChange returns a list of RegionID for a given store.
Expand Down Expand Up @@ -256,6 +260,13 @@ func (r *RaftEngine) TraverseRegions(lockedFunc func(*core.RegionInfo)) {
r.regionsInfo.TraverseRegions(lockedFunc)
}

// GetRegions gets all RegionInfo from regionMap
func (r *RaftEngine) GetRegions() []*core.RegionInfo {
r.RLock()
defer r.RUnlock()
return r.regionsInfo.GetRegions()
}

// SetRegion sets the RegionInfo with regionID
func (r *RaftEngine) SetRegion(region *core.RegionInfo) []*core.RegionInfo {
r.Lock()
Expand Down
11 changes: 9 additions & 2 deletions tools/pd-simulator/simulator/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,20 +517,25 @@ func processSnapshot(n *Node, stat *snapshotStat, speed uint64) bool {
return true
}
if stat.status == pending {
if stat.action == generate && n.stats.SendingSnapCount > maxSnapGeneratorPoolSize {
n.statsMutex.RLock()
sendSnapshot, receiveSnapshot := n.stats.SendingSnapCount, n.stats.ReceivingSnapCount
n.statsMutex.RUnlock()
if stat.action == generate && sendSnapshot > maxSnapGeneratorPoolSize {
return false
}
if stat.action == receive && n.stats.ReceivingSnapCount > maxSnapReceivePoolSize {
if stat.action == receive && receiveSnapshot > maxSnapReceivePoolSize {
return false
}
stat.status = running
stat.generateStart = time.Now()
n.statsMutex.Lock()
// If the statement is true, it will start to send or Receive the snapshot.
if stat.action == generate {
n.stats.SendingSnapCount++
} else {
n.stats.ReceivingSnapCount++
}
n.statsMutex.Unlock()
}

// store should Generate/Receive snapshot by chunk size.
Expand All @@ -548,11 +553,13 @@ func processSnapshot(n *Node, stat *snapshotStat, speed uint64) bool {
totalSec := uint64(time.Since(stat.start).Seconds()) * speed
generateSec := uint64(time.Since(stat.generateStart).Seconds()) * speed
n.registerSnapStats(generateSec, 0, totalSec)
n.statsMutex.Lock()
if stat.action == generate {
n.stats.SendingSnapCount--
} else {
n.stats.ReceivingSnapCount--
}
n.statsMutex.Unlock()
}
return true
}

0 comments on commit dff1346

Please sign in to comment.