diff --git a/pkg/timer/tablestore/store.go b/pkg/timer/tablestore/store.go index bbc00f8646b70..1e0fdbcbc07b2 100644 --- a/pkg/timer/tablestore/store.go +++ b/pkg/timer/tablestore/store.go @@ -117,7 +117,7 @@ func (s *tableTimerStoreCore) List(ctx context.Context, cond api.Cond) ([]*api.T } defer back() - if sessVars := sctx.GetSessionVars(); sessVars.GetEnableIndexMerge() { + if sessVars := sctx.GetSessionVars(); !sessVars.GetEnableIndexMerge() { // Enable index merge is used to make sure filtering timers with tags quickly. // Currently, we are using multi-value index to index tags for timers which requires index merge enabled. // see: https://docs.pingcap.com/tidb/dev/choose-index#use-a-multi-valued-index diff --git a/pkg/ttl/ttlworker/task_manager.go b/pkg/ttl/ttlworker/task_manager.go index cf4e531008eed..9285d86970e4d 100644 --- a/pkg/ttl/ttlworker/task_manager.go +++ b/pkg/ttl/ttlworker/task_manager.go @@ -50,27 +50,27 @@ const setTTLTaskFinishedTemplate = `UPDATE mysql.tidb_ttl_task SET status = 'finished', status_update_time = %?, state = %? - WHERE job_id = %? AND scan_id = %?` + WHERE job_id = %? AND scan_id = %? AND status = 'running' AND owner_id = %?` -func setTTLTaskFinishedSQL(jobID string, scanID int64, state *cache.TTLTaskState, now time.Time) (string, []any, error) { +func setTTLTaskFinishedSQL(jobID string, scanID int64, state *cache.TTLTaskState, now time.Time, ownerID string) (string, []any, error) { stateStr, err := json.Marshal(state) if err != nil { return "", nil, err } - return setTTLTaskFinishedTemplate, []any{now.Format(timeFormat), string(stateStr), jobID, scanID}, nil + return setTTLTaskFinishedTemplate, []any{now.Format(timeFormat), string(stateStr), jobID, scanID, ownerID}, nil } const updateTTLTaskHeartBeatTempalte = `UPDATE mysql.tidb_ttl_task SET state = %?, owner_hb_time = %? - WHERE job_id = %? AND scan_id = %?` + WHERE job_id = %? AND scan_id = %? 
AND owner_id = %?` -func updateTTLTaskHeartBeatSQL(jobID string, scanID int64, now time.Time, state *cache.TTLTaskState) (string, []any, error) { +func updateTTLTaskHeartBeatSQL(jobID string, scanID int64, now time.Time, state *cache.TTLTaskState, ownerID string) (string, []any, error) { stateStr, err := json.Marshal(state) if err != nil { return "", nil, err } - return updateTTLTaskHeartBeatTempalte, []any{string(stateStr), now.Format(timeFormat), jobID, scanID}, nil + return updateTTLTaskHeartBeatTempalte, []any{string(stateStr), now.Format(timeFormat), jobID, scanID, ownerID}, nil } const countRunningTasks = "SELECT count(1) FROM mysql.tidb_ttl_task WHERE status = 'running'" @@ -451,7 +451,7 @@ func (m *taskManager) updateHeartBeat(ctx context.Context, se session.Session, n } intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) - sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state) + sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state, m.id) if err != nil { return err } @@ -459,6 +459,11 @@ func (m *taskManager) updateHeartBeat(ctx context.Context, se session.Session, n if err != nil { return errors.Wrapf(err, "execute sql: %s", sql) } + + if se.GetSessionVars().StmtCtx.AffectedRows() != 1 { + return errors.Errorf("fail to update task status, maybe the owner is not myself (%s), affected rows: %d", + m.id, se.GetSessionVars().StmtCtx.AffectedRows()) + } } return nil } @@ -494,7 +499,7 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task } intest.Assert(se.GetSessionVars().Location().String() == now.Location().String()) - sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now) + sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now, m.id) if err != nil { return err } @@ -506,6 +511,10 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task if err != nil { return err } + if 
se.GetSessionVars().StmtCtx.AffectedRows() != 1 { + return errors.Errorf("fail to update task status, maybe the owner is not myself (%s) or task is not running, affected rows: %d", + m.id, se.GetSessionVars().StmtCtx.AffectedRows()) + } return nil } @@ -516,31 +525,37 @@ func (m *taskManager) checkInvalidTask(se session.Session) { ownRunningTask := make([]*runningScanTask, 0, len(m.runningTasks)) for _, task := range m.runningTasks { + logger := logutil.Logger(m.ctx).With(zap.String("jobID", task.JobID), zap.Int64("scanID", task.ScanID)) + sql, args := cache.SelectFromTTLTaskWithID(task.JobID, task.ScanID) timeoutCtx, cancel := context.WithTimeout(m.ctx, ttlInternalSQLTimeout) rows, err := se.ExecuteSQL(timeoutCtx, sql, args...) cancel() if err != nil { - logutil.Logger(m.ctx).Warn("fail to execute sql", zap.String("sql", sql), zap.Any("args", args), zap.Error(err)) + logger.Warn("fail to execute sql", zap.String("sql", sql), zap.Any("args", args), zap.Error(err)) task.cancel() continue } if len(rows) == 0 { - logutil.Logger(m.ctx).Warn("didn't find task", zap.String("jobID", task.JobID), zap.Int64("scanID", task.ScanID)) + logger.Warn("didn't find task") task.cancel() continue } t, err := cache.RowToTTLTask(se, rows[0]) if err != nil { - logutil.Logger(m.ctx).Warn("fail to get task", zap.Error(err)) + logger.Warn("fail to get task", zap.Error(err)) task.cancel() continue } - if t.OwnerID == m.id { - ownRunningTask = append(ownRunningTask, task) + if t.OwnerID != m.id { + logger.Warn("task owner changed", zap.String("myOwnerID", m.id), zap.String("taskOwnerID", t.OwnerID)) + task.cancel() + continue } + + ownRunningTask = append(ownRunningTask, task) } m.runningTasks = ownRunningTask diff --git a/pkg/ttl/ttlworker/task_manager_integration_test.go b/pkg/ttl/ttlworker/task_manager_integration_test.go index 94cfe44348d3c..6d34f79c2c476 100644 --- a/pkg/ttl/ttlworker/task_manager_integration_test.go +++ b/pkg/ttl/ttlworker/task_manager_integration_test.go @@ -429,3 
+429,100 @@ func TestShrinkScanWorkerTimeout(t *testing.T) { require.NoError(t, m.ResizeDelWorkers(0)) close(blockCancelCh) } + +func TestTaskCancelledAfterHeartbeatTimeout(t *testing.T) { + store, dom := testkit.CreateMockStoreAndDomain(t) + pool := wrapPoolForTest(dom.SysSessionPool()) + waitAndStopTTLManager(t, dom) + tk := testkit.NewTestKit(t, store) + sessionFactory := sessionFactory(t, store) + se := sessionFactory() + + tk.MustExec("set global tidb_ttl_running_tasks = 128") + defer tk.MustExec("set global tidb_ttl_running_tasks = -1") + + tk.MustExec("create table test.t(id int, created_at datetime) ttl=created_at + interval 1 day") + table, err := dom.InfoSchema().TableByName(context.Background(), model.NewCIStr("test"), model.NewCIStr("t")) + require.NoError(t, err) + // 4 tasks are inserted into the table + for i := 0; i < 4; i++ { + sql := fmt.Sprintf("insert into mysql.tidb_ttl_task(job_id,table_id,scan_id,expire_time,created_time) values ('test-job', %d, %d, NOW(), NOW())", table.Meta().ID, i) + tk.MustExec(sql) + } + isc := cache.NewInfoSchemaCache(time.Second) + require.NoError(t, isc.Update(se)) + + workers := []ttlworker.Worker{} + for j := 0; j < 8; j++ { + scanWorker := ttlworker.NewMockScanWorker(t) + scanWorker.Start() + workers = append(workers, scanWorker) + } + + now := se.Now() + m1 := ttlworker.NewTaskManager(context.Background(), pool, isc, "task-manager-1", store) + m1.SetScanWorkers4Test(workers[0:4]) + m1.RescheduleTasks(se, now) + m2 := ttlworker.NewTaskManager(context.Background(), pool, isc, "task-manager-2", store) + m2.SetScanWorkers4Test(workers[4:]) + + // All tasks should be scheduled to m1 and running + tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running' and owner_id = 'task-manager-1'").Check(testkit.Rows("4")) + + var cancelCount atomic.Uint32 + for i := 0; i < 4; i++ { + task := m1.GetRunningTasks()[i] + task.SetCancel(func() { + cancelCount.Add(1) + }) + } + + // After a period of time, the 
tasks lost heartbeat and will be re-assigned to m2 + now = now.Add(time.Hour) + m2.RescheduleTasks(se, now) + + // All tasks should be scheduled to m2 and running + tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running' and owner_id = 'task-manager-2'").Check(testkit.Rows("4")) + + // Then m1 cannot update the heartbeat of its task + require.Error(t, m1.UpdateHeartBeat(context.Background(), se, now.Add(time.Hour))) + tk.MustQuery("select owner_hb_time from mysql.tidb_ttl_task").Check(testkit.Rows( + now.Format(time.DateTime), + now.Format(time.DateTime), + now.Format(time.DateTime), + now.Format(time.DateTime), + )) + + // m2 can successfully update the heartbeat + require.NoError(t, m2.UpdateHeartBeat(context.Background(), se, now.Add(time.Hour))) + tk.MustQuery("select owner_hb_time from mysql.tidb_ttl_task").Check(testkit.Rows( + now.Add(time.Hour).Format(time.DateTime), + now.Add(time.Hour).Format(time.DateTime), + now.Add(time.Hour).Format(time.DateTime), + now.Add(time.Hour).Format(time.DateTime), + )) + + // Although m1 cannot finish the task, it will still try to cancel the task. + for _, task := range m1.GetRunningTasks() { + task.SetResult(nil) + } + m1.CheckFinishedTask(se, now) + tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running'").Check(testkit.Rows("4")) + require.Equal(t, uint32(4), cancelCount.Load()) + + // Then the tasks in m1 should be cancelled again in `CheckInvalidTask`. 
+ m1.CheckInvalidTask(se) + require.Equal(t, uint32(8), cancelCount.Load()) + + // m2 can finish the task + for _, task := range m2.GetRunningTasks() { + task.SetResult(nil) + } + m2.CheckFinishedTask(se, now) + tk.MustQuery("select status, state, owner_id from mysql.tidb_ttl_task").Sort().Check(testkit.Rows( + `finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`, + `finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`, + `finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`, + `finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`, + )) +} diff --git a/pkg/ttl/ttlworker/task_manager_test.go b/pkg/ttl/ttlworker/task_manager_test.go index 2ff3bb212a043..bd2381b1e3eff 100644 --- a/pkg/ttl/ttlworker/task_manager_test.go +++ b/pkg/ttl/ttlworker/task_manager_test.go @@ -101,6 +101,21 @@ func (t *runningScanTask) SetResult(err error) { t.result = t.ttlScanTask.result(err) } +// SetCancel sets the cancel function of the task +func (t *runningScanTask) SetCancel(cancel func()) { + t.cancel = cancel +} + +// CheckInvalidTask is an exported version of checkInvalidTask +func (m *taskManager) CheckInvalidTask(se session.Session) { + m.checkInvalidTask(se) +} + +// UpdateHeartBeat is an exported version of updateHeartBeat +func (m *taskManager) UpdateHeartBeat(ctx context.Context, se session.Session, now time.Time) error { + return m.updateHeartBeat(ctx, se, now) +} + func TestResizeWorkers(t *testing.T) { tbl := newMockTTLTbl(t, "t1")