Skip to content

Commit

Permalink
*: pass context to task (tikv#8429)
Browse files Browse the repository at this point in the history
close tikv#8386

Signed-off-by: Ryan Leung <rleungx@gmail.com>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
rleungx and ti-chi-bot[bot] authored Jul 23, 2024
1 parent 624b6f3 commit 0c56739
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 34 deletions.
16 changes: 14 additions & 2 deletions pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package cluster

import (
"context"

"github.com/tikv/pd/pkg/core"
"github.com/tikv/pd/pkg/schedule"
"github.com/tikv/pd/pkg/schedule/placement"
Expand Down Expand Up @@ -56,8 +58,13 @@ func HandleStatsAsync(c Cluster, region *core.RegionInfo) {
}

// HandleOverlaps handles the overlap regions.
func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) {
func HandleOverlaps(ctx context.Context, c Cluster, overlaps []*core.RegionInfo) {
for _, item := range overlaps {
select {
case <-ctx.Done():
return
default:
}
if c.GetRegionStats() != nil {
c.GetRegionStats().ClearDefunctRegion(item.GetID())
}
Expand All @@ -67,7 +74,7 @@ func HandleOverlaps(c Cluster, overlaps []*core.RegionInfo) {
}

// Collect collects the cluster information.
func Collect(c Cluster, region *core.RegionInfo, hasRegionStats bool) {
func Collect(ctx context.Context, c Cluster, region *core.RegionInfo, hasRegionStats bool) {
if hasRegionStats {
// get region again from root tree. make sure the observed region is the latest.
bc := c.GetBasicCluster()
Expand All @@ -78,6 +85,11 @@ func Collect(c Cluster, region *core.RegionInfo, hasRegionStats bool) {
if region == nil {
return
}
select {
case <-ctx.Done():
return
default:
}
c.GetRegionStats().Observe(region, c.GetBasicCluster().GetRegionStores(region))
}
}
5 changes: 3 additions & 2 deletions pkg/core/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package core

import (
"bytes"
"context"
"encoding/hex"
"fmt"
"math"
Expand Down Expand Up @@ -750,7 +751,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc {
logRunner.RunTask(
regionID,
"DebugLog",
func() {
func(context.Context) {
d(msg, fields...)
},
)
Expand All @@ -759,7 +760,7 @@ func GenerateRegionGuideFunc(enableLog bool) RegionGuideFunc {
logRunner.RunTask(
regionID,
"InfoLog",
func() {
func(context.Context) {
i(msg, fields...)
},
)
Expand Down
18 changes: 8 additions & 10 deletions pkg/mcs/scheduling/server/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,8 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
ctx.TaskRunner.RunTask(
regionID,
ratelimit.ObserveRegionStatsAsync,
func() {
if c.regionStats.RegionStatsNeedUpdate(region) {
cluster.Collect(c, region, hasRegionStats)
}
func(ctx context.Context) {
cluster.Collect(ctx, c, region, hasRegionStats)
},
)
}
Expand All @@ -639,7 +637,7 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
ctx.TaskRunner.RunTask(
regionID,
ratelimit.UpdateSubTree,
func() {
func(context.Context) {
c.CheckAndPutSubTree(region)
},
ratelimit.WithRetained(true),
Expand All @@ -663,7 +661,7 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
ctx.TaskRunner.RunTask(
regionID,
ratelimit.UpdateSubTree,
func() {
func(context.Context) {
c.CheckAndPutSubTree(region)
},
ratelimit.WithRetained(retained),
Expand All @@ -672,8 +670,8 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
ctx.TaskRunner.RunTask(
regionID,
ratelimit.HandleOverlaps,
func() {
cluster.HandleOverlaps(c, overlaps)
func(ctx context.Context) {
cluster.HandleOverlaps(ctx, c, overlaps)
},
)
}
Expand All @@ -682,8 +680,8 @@ func (c *Cluster) processRegionHeartbeat(ctx *core.MetaProcessContext, region *c
ctx.TaskRunner.RunTask(
regionID,
ratelimit.CollectRegionStatsAsync,
func() {
cluster.Collect(c, region, hasRegionStats)
func(ctx context.Context) {
cluster.Collect(ctx, c, region, hasRegionStats)
},
)
tracer.OnCollectRegionStatsFinished()
Expand Down
12 changes: 6 additions & 6 deletions pkg/ratelimit/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const (

// Runner is the interface for running tasks.
type Runner interface {
RunTask(id uint64, name string, f func(), opts ...TaskOption) error
RunTask(id uint64, name string, f func(context.Context), opts ...TaskOption) error
Start(ctx context.Context)
Stop()
}
Expand All @@ -51,7 +51,7 @@ type Runner interface {
type Task struct {
id uint64
submittedAt time.Time
f func()
f func(context.Context)
name string
// retained indicates whether the task should be dropped if the task queue exceeds maxPendingDuration.
retained bool
Expand Down Expand Up @@ -152,7 +152,7 @@ func (cr *ConcurrentRunner) run(ctx context.Context, task *Task, token *TaskToke
return
default:
}
task.f()
task.f(ctx)
if token != nil {
cr.limiter.ReleaseToken(token)
cr.processPendingTasks()
Expand Down Expand Up @@ -184,7 +184,7 @@ func (cr *ConcurrentRunner) Stop() {
}

// RunTask runs the task asynchronously.
func (cr *ConcurrentRunner) RunTask(id uint64, name string, f func(), opts ...TaskOption) error {
func (cr *ConcurrentRunner) RunTask(id uint64, name string, f func(context.Context), opts ...TaskOption) error {
task := &Task{
id: id,
name: name,
Expand Down Expand Up @@ -238,8 +238,8 @@ func NewSyncRunner() *SyncRunner {
}

// RunTask runs the task synchronously.
func (*SyncRunner) RunTask(_ uint64, _ string, f func(), _ ...TaskOption) error {
f()
func (*SyncRunner) RunTask(_ uint64, _ string, f func(context.Context), _ ...TaskOption) error {
f(context.Background())
return nil
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/ratelimit/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestConcurrentRunner(t *testing.T) {
err := runner.RunTask(
uint64(i),
"test1",
func() {
func(context.Context) {
defer wg.Done()
time.Sleep(100 * time.Millisecond)
},
Expand All @@ -56,7 +56,7 @@ func TestConcurrentRunner(t *testing.T) {
err := runner.RunTask(
uint64(i),
"test2",
func() {
func(context.Context) {
defer wg.Done()
time.Sleep(100 * time.Millisecond)
},
Expand Down Expand Up @@ -87,7 +87,7 @@ func TestConcurrentRunner(t *testing.T) {
err := runner.RunTask(
regionID,
"test3",
func() {
func(context.Context) {
time.Sleep(time.Second)
},
)
Expand Down
20 changes: 9 additions & 11 deletions server/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -1061,10 +1061,8 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
ctx.MiscRunner.RunTask(
regionID,
ratelimit.ObserveRegionStatsAsync,
func() {
if c.regionStats.RegionStatsNeedUpdate(region) {
cluster.Collect(c, region, hasRegionStats)
}
func(ctx context.Context) {
cluster.Collect(ctx, c, region, hasRegionStats)
},
)
}
Expand All @@ -1073,7 +1071,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
ctx.TaskRunner.RunTask(
regionID,
ratelimit.UpdateSubTree,
func() {
func(context.Context) {
c.CheckAndPutSubTree(region)
},
ratelimit.WithRetained(true),
Expand Down Expand Up @@ -1101,7 +1099,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
ctx.TaskRunner.RunTask(
regionID,
ratelimit.UpdateSubTree,
func() {
func(context.Context) {
c.CheckAndPutSubTree(region)
},
ratelimit.WithRetained(retained),
Expand All @@ -1112,8 +1110,8 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
ctx.MiscRunner.RunTask(
regionID,
ratelimit.HandleOverlaps,
func() {
cluster.HandleOverlaps(c, overlaps)
func(ctx context.Context) {
cluster.HandleOverlaps(ctx, c, overlaps)
},
)
}
Expand All @@ -1125,11 +1123,11 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
ctx.MiscRunner.RunTask(
regionID,
ratelimit.CollectRegionStatsAsync,
func() {
func(ctx context.Context) {
// TODO: Due to the accuracy requirements of the API "/regions/check/xxx",
// region stats needs to be collected in API mode.
// We need to think of a better way to reduce this part of the cost in the future.
cluster.Collect(c, region, hasRegionStats)
cluster.Collect(ctx, c, region, hasRegionStats)
},
)

Expand All @@ -1139,7 +1137,7 @@ func (c *RaftCluster) processRegionHeartbeat(ctx *core.MetaProcessContext, regio
ctx.MiscRunner.RunTask(
regionID,
ratelimit.SaveRegionToKV,
func() {
func(context.Context) {
// If there are concurrent heartbeats from the same region, the last write will win even if
// writes to storage in the critical area. So don't use mutex to protect it.
// Not successfully saved to storage is not fatal, it only leads to longer warm-up
Expand Down

0 comments on commit 0c56739

Please sign in to comment.