Skip to content

Commit

Permalink
fix the issue that the task is not cancelled after transfering owners
Browse files Browse the repository at this point in the history
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
  • Loading branch information
YangKeao committed Nov 28, 2024
1 parent 37a1f42 commit 04da03a
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pkg/timer/tablestore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (s *tableTimerStoreCore) List(ctx context.Context, cond api.Cond) ([]*api.T
}
defer back()

if sessVars := sctx.GetSessionVars(); sessVars.GetEnableIndexMerge() {
if sessVars := sctx.GetSessionVars(); !sessVars.GetEnableIndexMerge() {
// Enable index merge is used to make sure filtering timers with tags quickly.
// Currently, we are using multi-value index to index tags for timers which requires index merge enabled.
// see: https://docs.pingcap.com/tidb/dev/choose-index#use-a-multi-valued-index
Expand Down
41 changes: 28 additions & 13 deletions pkg/ttl/ttlworker/task_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,27 @@ const setTTLTaskFinishedTemplate = `UPDATE mysql.tidb_ttl_task
SET status = 'finished',
status_update_time = %?,
state = %?
WHERE job_id = %? AND scan_id = %?`
WHERE job_id = %? AND scan_id = %? AND status = 'running' AND owner_id = %?`

func setTTLTaskFinishedSQL(jobID string, scanID int64, state *cache.TTLTaskState, now time.Time) (string, []any, error) {
func setTTLTaskFinishedSQL(jobID string, scanID int64, state *cache.TTLTaskState, now time.Time, ownerID string) (string, []any, error) {
stateStr, err := json.Marshal(state)
if err != nil {
return "", nil, err
}
return setTTLTaskFinishedTemplate, []any{now.Format(timeFormat), string(stateStr), jobID, scanID}, nil
return setTTLTaskFinishedTemplate, []any{now.Format(timeFormat), string(stateStr), jobID, scanID, ownerID}, nil
}

const updateTTLTaskHeartBeatTempalte = `UPDATE mysql.tidb_ttl_task
SET state = %?,
owner_hb_time = %?
WHERE job_id = %? AND scan_id = %?`
WHERE job_id = %? AND scan_id = %? AND owner_id = %?`

func updateTTLTaskHeartBeatSQL(jobID string, scanID int64, now time.Time, state *cache.TTLTaskState) (string, []any, error) {
func updateTTLTaskHeartBeatSQL(jobID string, scanID int64, now time.Time, state *cache.TTLTaskState, ownerID string) (string, []any, error) {
stateStr, err := json.Marshal(state)
if err != nil {
return "", nil, err
}
return updateTTLTaskHeartBeatTempalte, []any{string(stateStr), now.Format(timeFormat), jobID, scanID}, nil
return updateTTLTaskHeartBeatTempalte, []any{string(stateStr), now.Format(timeFormat), jobID, scanID, ownerID}, nil
}

const countRunningTasks = "SELECT count(1) FROM mysql.tidb_ttl_task WHERE status = 'running'"
Expand Down Expand Up @@ -449,14 +449,19 @@ func (m *taskManager) updateHeartBeat(ctx context.Context, se session.Session, n
}

intest.Assert(se.GetSessionVars().Location().String() == now.Location().String())
sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state)
sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state, m.id)
if err != nil {
return err
}
_, err = se.ExecuteSQL(ctx, sql, args...)
if err != nil {
return errors.Wrapf(err, "execute sql: %s", sql)
}

if se.GetSessionVars().StmtCtx.AffectedRows() != 1 {
return errors.Errorf("fail to update task status, maybe the owner is not myself, affected rows: %d",
se.GetSessionVars().StmtCtx.AffectedRows())
}
}
return nil
}
Expand Down Expand Up @@ -492,7 +497,7 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task
}

intest.Assert(se.GetSessionVars().Location().String() == now.Location().String())
sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now)
sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now, m.id)
if err != nil {
return err
}
Expand All @@ -504,6 +509,10 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task
if err != nil {
return err
}
if se.GetSessionVars().StmtCtx.AffectedRows() != 1 {
return errors.Errorf("fail to update task status, maybe the owner is not myself or task is not running, affected rows: %d",
se.GetSessionVars().StmtCtx.AffectedRows())
}

return nil
}
Expand All @@ -514,31 +523,37 @@ func (m *taskManager) checkInvalidTask(se session.Session) {
ownRunningTask := make([]*runningScanTask, 0, len(m.runningTasks))

for _, task := range m.runningTasks {
logger := logutil.Logger(m.ctx).With(zap.String("jobID", task.JobID), zap.Int64("scanID", task.ScanID))

sql, args := cache.SelectFromTTLTaskWithID(task.JobID, task.ScanID)

timeoutCtx, cancel := context.WithTimeout(m.ctx, ttlInternalSQLTimeout)
rows, err := se.ExecuteSQL(timeoutCtx, sql, args...)
cancel()
if err != nil {
logutil.Logger(m.ctx).Warn("fail to execute sql", zap.String("sql", sql), zap.Any("args", args), zap.Error(err))
logger.Warn("fail to execute sql", zap.String("sql", sql), zap.Any("args", args), zap.Error(err))
task.cancel()
continue
}
if len(rows) == 0 {
logutil.Logger(m.ctx).Warn("didn't find task", zap.String("jobID", task.JobID), zap.Int64("scanID", task.ScanID))
logger.Warn("didn't find task")
task.cancel()
continue
}
t, err := cache.RowToTTLTask(se, rows[0])
if err != nil {
logutil.Logger(m.ctx).Warn("fail to get task", zap.Error(err))
logger.Warn("fail to get task", zap.Error(err))
task.cancel()
continue
}

if t.OwnerID == m.id {
ownRunningTask = append(ownRunningTask, task)
if t.OwnerID != m.id {
logger.Warn("task owner changed", zap.String("myOwnerID", m.id), zap.String("taskOwnerID", t.OwnerID))
task.cancel()
continue
}

ownRunningTask = append(ownRunningTask, task)
}

m.runningTasks = ownRunningTask
Expand Down
85 changes: 85 additions & 0 deletions pkg/ttl/ttlworker/task_manager_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,88 @@ func TestMeetTTLRunningTasks(t *testing.T) {
require.False(t, dom.TTLJobManager().TaskManager().MeetTTLRunningTasks(3, cache.TaskStatusWaiting))
require.True(t, dom.TTLJobManager().TaskManager().MeetTTLRunningTasks(3, cache.TaskStatusRunning))
}

func TestTaskCancelledAfterHeartbeatTimeout(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
pool := wrapPoolForTest(dom.SysSessionPool())
waitAndStopTTLManager(t, dom)
tk := testkit.NewTestKit(t, store)
sessionFactory := sessionFactory(t, store)
se := sessionFactory()

tk.MustExec("set global tidb_ttl_running_tasks = 128")
defer tk.MustExec("set global tidb_ttl_running_tasks = -1")

tk.MustExec("create table test.t(id int, created_at datetime) ttl=created_at + interval 1 day")
table, err := dom.InfoSchema().TableByName(context.Background(), model.NewCIStr("test"), model.NewCIStr("t"))
require.NoError(t, err)
// 4 tasks are inserted into the table
for i := 0; i < 4; i++ {
sql := fmt.Sprintf("insert into mysql.tidb_ttl_task(job_id,table_id,scan_id,expire_time,created_time) values ('test-job', %d, %d, NOW(), NOW())", table.Meta().ID, i)
tk.MustExec(sql)
}
isc := cache.NewInfoSchemaCache(time.Second)
require.NoError(t, isc.Update(se))

workers := []ttlworker.Worker{}
for j := 0; j < 8; j++ {
scanWorker := ttlworker.NewMockScanWorker(t)
scanWorker.Start()
workers = append(workers, scanWorker)
}

now := se.Now()
m1 := ttlworker.NewTaskManager(context.Background(), pool, isc, "task-manager-1", store)
m1.SetScanWorkers4Test(workers[0:4])
m1.RescheduleTasks(se, now)
m2 := ttlworker.NewTaskManager(context.Background(), pool, isc, "task-manager-2", store)
m2.SetScanWorkers4Test(workers[4:])

// All tasks should be scheduled to m1 and running
tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running' and owner_id = 'task-manager-1'").Check(testkit.Rows("4"))

var cancelCount atomic.Uint32
for i := 0; i < 4; i++ {
task := m1.GetRunningTasks()[i]
task.SetCancel(func() {
cancelCount.Add(1)
})
}

// After a period of time, the tasks lost heartbeat and will be re-asisgned to m2
now = now.Add(time.Hour)
m2.RescheduleTasks(se, now)

// All tasks should be scheduled to m2 and running
tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running' and owner_id = 'task-manager-2'").Check(testkit.Rows("4"))

// Then m1 cannot update the heartbeat of its task
require.Error(t, m1.UpdateHeartBeat(context.Background(), se, now.Add(time.Hour)))
tk.MustQuery("select owner_hb_time from mysql.tidb_ttl_task").Check(testkit.Rows(
now.Format(time.DateTime),
now.Format(time.DateTime),
now.Format(time.DateTime),
now.Format(time.DateTime),
))

// m2 can successfully update the heartbeat
require.NoError(t, m2.UpdateHeartBeat(context.Background(), se, now.Add(time.Hour)))
tk.MustQuery("select owner_hb_time from mysql.tidb_ttl_task").Check(testkit.Rows(
now.Add(time.Hour).Format(time.DateTime),
now.Add(time.Hour).Format(time.DateTime),
now.Add(time.Hour).Format(time.DateTime),
now.Add(time.Hour).Format(time.DateTime),
))

// Also m1 cannot finish the task
for _, task := range m1.GetRunningTasks() {
task.SetResult(nil)
}
m1.CheckFinishedTask(se, now)
tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running'").Check(testkit.Rows("4"))
require.Equal(t, uint32(4), cancelCount.Load())

// Then the tasks in m1 should be cancelled
m1.CheckInvalidTask(se)
require.Equal(t, uint32(8), cancelCount.Load())
}
15 changes: 15 additions & 0 deletions pkg/ttl/ttlworker/task_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ func (t *runningScanTask) SetResult(err error) {
t.result = t.ttlScanTask.result(err)
}

// SetCancel sets the cancel function of the task
func (t *runningScanTask) SetCancel(cancel func()) {
t.cancel = cancel
}

// CheckInvalidTask is an exported version of checkInvalidTask
func (m *taskManager) CheckInvalidTask(se session.Session) {
m.checkInvalidTask(se)
}

// UpdateHeartBeat is an exported version of updateHeartBeat
func (m *taskManager) UpdateHeartBeat(ctx context.Context, se session.Session, now time.Time) error {
return m.updateHeartBeat(ctx, se, now)
}

func TestResizeWorkers(t *testing.T) {
tbl := newMockTTLTbl(t, "t1")

Expand Down
10 changes: 6 additions & 4 deletions pkg/ttl/ttlworker/timer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx/variable"
timerapi "github.com/pingcap/tidb/pkg/timer/api"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/timeutil"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -350,8 +351,9 @@ func TestTTLTimerHookOnEvent(t *testing.T) {
require.Equal(t, summaryData, timer.SummaryData)
adapter.AssertExpectations(t)

tz := timeutil.SystemLocation()
// job not exists but table ttl not enabled
watermark := time.Unix(3600*123, 0)
watermark := time.Unix(3600*123, 0).In(tz)
require.NoError(t, cli.UpdateTimer(ctx, timer.ID, timerapi.WithSetWatermark(watermark)))
timer = triggerTestTimer(t, store, timer.ID)
adapter.On("GetJob", ctx, data.TableID, data.PhysicalID, timer.EventID).
Expand All @@ -373,7 +375,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) {
require.Equal(t, oldSummary, timer.SummaryData)

// job not exists but timer disabled
watermark = time.Unix(3600*456, 0)
watermark = time.Unix(3600*456, 0).In(tz)
require.NoError(t, cli.UpdateTimer(ctx, timer.ID, timerapi.WithSetWatermark(watermark), timerapi.WithSetEnable(false)))
timer = triggerTestTimer(t, store, timer.ID)
adapter.On("GetJob", ctx, data.TableID, data.PhysicalID, timer.EventID).
Expand All @@ -394,7 +396,7 @@ func TestTTLTimerHookOnEvent(t *testing.T) {
require.NoError(t, cli.UpdateTimer(ctx, timer.ID, timerapi.WithSetEnable(true)))

// job not exists but event start too early
watermark = time.Unix(3600*789, 0)
watermark = time.Unix(3600*789, 0).In(tz)
require.NoError(t, cli.UpdateTimer(ctx, timer.ID, timerapi.WithSetWatermark(watermark)))
timer = triggerTestTimer(t, store, timer.ID)
adapter.On("Now").Return(timer.EventStart.Add(11*time.Minute), nil).Once()
Expand Down Expand Up @@ -584,7 +586,7 @@ func TestGetTTLSchedulePolicy(t *testing.T) {
JobInterval: "",
})
require.Equal(t, timerapi.SchedEventInterval, tp)
require.Equal(t, model.OldDefaultTTLJobInterval, expr)
require.Equal(t, model.DefaultJobIntervalStr, expr)
_, err = timerapi.CreateSchedEventPolicy(tp, expr)
require.NoError(t, err)
}

0 comments on commit 04da03a

Please sign in to comment.