Skip to content

Commit

Permalink
ttl: fix the issue that the task is not cancelled after transfering o…
Browse files Browse the repository at this point in the history
…wners (#57788)

close #57784
  • Loading branch information
YangKeao authored Dec 2, 2024
1 parent d43ba5c commit 018ab99
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pkg/timer/tablestore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func (s *tableTimerStoreCore) List(ctx context.Context, cond api.Cond) ([]*api.T
}
defer back()

if sessVars := sctx.GetSessionVars(); sessVars.GetEnableIndexMerge() {
if sessVars := sctx.GetSessionVars(); !sessVars.GetEnableIndexMerge() {
// Enable index merge is used to make sure filtering timers with tags quickly.
// Currently, we are using multi-value index to index tags for timers which requires index merge enabled.
// see: https://docs.pingcap.com/tidb/dev/choose-index#use-a-multi-valued-index
Expand Down
41 changes: 28 additions & 13 deletions pkg/ttl/ttlworker/task_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,27 @@ const setTTLTaskFinishedTemplate = `UPDATE mysql.tidb_ttl_task
SET status = 'finished',
status_update_time = %?,
state = %?
WHERE job_id = %? AND scan_id = %?`
WHERE job_id = %? AND scan_id = %? AND status = 'running' AND owner_id = %?`

func setTTLTaskFinishedSQL(jobID string, scanID int64, state *cache.TTLTaskState, now time.Time) (string, []any, error) {
func setTTLTaskFinishedSQL(jobID string, scanID int64, state *cache.TTLTaskState, now time.Time, ownerID string) (string, []any, error) {
stateStr, err := json.Marshal(state)
if err != nil {
return "", nil, err
}
return setTTLTaskFinishedTemplate, []any{now.Format(timeFormat), string(stateStr), jobID, scanID}, nil
return setTTLTaskFinishedTemplate, []any{now.Format(timeFormat), string(stateStr), jobID, scanID, ownerID}, nil
}

const updateTTLTaskHeartBeatTempalte = `UPDATE mysql.tidb_ttl_task
SET state = %?,
owner_hb_time = %?
WHERE job_id = %? AND scan_id = %?`
WHERE job_id = %? AND scan_id = %? AND owner_id = %?`

func updateTTLTaskHeartBeatSQL(jobID string, scanID int64, now time.Time, state *cache.TTLTaskState) (string, []any, error) {
func updateTTLTaskHeartBeatSQL(jobID string, scanID int64, now time.Time, state *cache.TTLTaskState, ownerID string) (string, []any, error) {
stateStr, err := json.Marshal(state)
if err != nil {
return "", nil, err
}
return updateTTLTaskHeartBeatTempalte, []any{string(stateStr), now.Format(timeFormat), jobID, scanID}, nil
return updateTTLTaskHeartBeatTempalte, []any{string(stateStr), now.Format(timeFormat), jobID, scanID, ownerID}, nil
}

const countRunningTasks = "SELECT count(1) FROM mysql.tidb_ttl_task WHERE status = 'running'"
Expand Down Expand Up @@ -451,14 +451,19 @@ func (m *taskManager) updateHeartBeat(ctx context.Context, se session.Session, n
}

intest.Assert(se.GetSessionVars().Location().String() == now.Location().String())
sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state)
sql, args, err := updateTTLTaskHeartBeatSQL(task.JobID, task.ScanID, now, state, m.id)
if err != nil {
return err
}
_, err = se.ExecuteSQL(ctx, sql, args...)
if err != nil {
return errors.Wrapf(err, "execute sql: %s", sql)
}

if se.GetSessionVars().StmtCtx.AffectedRows() != 1 {
return errors.Errorf("fail to update task status, maybe the owner is not myself (%s), affected rows: %d",
m.id, se.GetSessionVars().StmtCtx.AffectedRows())
}
}
return nil
}
Expand Down Expand Up @@ -494,7 +499,7 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task
}

intest.Assert(se.GetSessionVars().Location().String() == now.Location().String())
sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now)
sql, args, err := setTTLTaskFinishedSQL(task.JobID, task.ScanID, state, now, m.id)
if err != nil {
return err
}
Expand All @@ -506,6 +511,10 @@ func (m *taskManager) reportTaskFinished(se session.Session, now time.Time, task
if err != nil {
return err
}
if se.GetSessionVars().StmtCtx.AffectedRows() != 1 {
return errors.Errorf("fail to update task status, maybe the owner is not myself (%s) or task is not running, affected rows: %d",
m.id, se.GetSessionVars().StmtCtx.AffectedRows())
}

return nil
}
Expand All @@ -516,31 +525,37 @@ func (m *taskManager) checkInvalidTask(se session.Session) {
ownRunningTask := make([]*runningScanTask, 0, len(m.runningTasks))

for _, task := range m.runningTasks {
logger := logutil.Logger(m.ctx).With(zap.String("jobID", task.JobID), zap.Int64("scanID", task.ScanID))

sql, args := cache.SelectFromTTLTaskWithID(task.JobID, task.ScanID)

timeoutCtx, cancel := context.WithTimeout(m.ctx, ttlInternalSQLTimeout)
rows, err := se.ExecuteSQL(timeoutCtx, sql, args...)
cancel()
if err != nil {
logutil.Logger(m.ctx).Warn("fail to execute sql", zap.String("sql", sql), zap.Any("args", args), zap.Error(err))
logger.Warn("fail to execute sql", zap.String("sql", sql), zap.Any("args", args), zap.Error(err))
task.cancel()
continue
}
if len(rows) == 0 {
logutil.Logger(m.ctx).Warn("didn't find task", zap.String("jobID", task.JobID), zap.Int64("scanID", task.ScanID))
logger.Warn("didn't find task")
task.cancel()
continue
}
t, err := cache.RowToTTLTask(se, rows[0])
if err != nil {
logutil.Logger(m.ctx).Warn("fail to get task", zap.Error(err))
logger.Warn("fail to get task", zap.Error(err))
task.cancel()
continue
}

if t.OwnerID == m.id {
ownRunningTask = append(ownRunningTask, task)
if t.OwnerID != m.id {
logger.Warn("task owner changed", zap.String("myOwnerID", m.id), zap.String("taskOwnerID", t.OwnerID))
task.cancel()
continue
}

ownRunningTask = append(ownRunningTask, task)
}

m.runningTasks = ownRunningTask
Expand Down
97 changes: 97 additions & 0 deletions pkg/ttl/ttlworker/task_manager_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,100 @@ func TestShrinkScanWorkerTimeout(t *testing.T) {
require.NoError(t, m.ResizeDelWorkers(0))
close(blockCancelCh)
}

func TestTaskCancelledAfterHeartbeatTimeout(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
pool := wrapPoolForTest(dom.SysSessionPool())
waitAndStopTTLManager(t, dom)
tk := testkit.NewTestKit(t, store)
sessionFactory := sessionFactory(t, store)
se := sessionFactory()

tk.MustExec("set global tidb_ttl_running_tasks = 128")
defer tk.MustExec("set global tidb_ttl_running_tasks = -1")

tk.MustExec("create table test.t(id int, created_at datetime) ttl=created_at + interval 1 day")
table, err := dom.InfoSchema().TableByName(context.Background(), model.NewCIStr("test"), model.NewCIStr("t"))
require.NoError(t, err)
// 4 tasks are inserted into the table
for i := 0; i < 4; i++ {
sql := fmt.Sprintf("insert into mysql.tidb_ttl_task(job_id,table_id,scan_id,expire_time,created_time) values ('test-job', %d, %d, NOW(), NOW())", table.Meta().ID, i)
tk.MustExec(sql)
}
isc := cache.NewInfoSchemaCache(time.Second)
require.NoError(t, isc.Update(se))

workers := []ttlworker.Worker{}
for j := 0; j < 8; j++ {
scanWorker := ttlworker.NewMockScanWorker(t)
scanWorker.Start()
workers = append(workers, scanWorker)
}

now := se.Now()
m1 := ttlworker.NewTaskManager(context.Background(), pool, isc, "task-manager-1", store)
m1.SetScanWorkers4Test(workers[0:4])
m1.RescheduleTasks(se, now)
m2 := ttlworker.NewTaskManager(context.Background(), pool, isc, "task-manager-2", store)
m2.SetScanWorkers4Test(workers[4:])

// All tasks should be scheduled to m1 and running
tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running' and owner_id = 'task-manager-1'").Check(testkit.Rows("4"))

var cancelCount atomic.Uint32
for i := 0; i < 4; i++ {
task := m1.GetRunningTasks()[i]
task.SetCancel(func() {
cancelCount.Add(1)
})
}

// After a period of time, the tasks lost heartbeat and will be re-asisgned to m2
now = now.Add(time.Hour)
m2.RescheduleTasks(se, now)

// All tasks should be scheduled to m2 and running
tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running' and owner_id = 'task-manager-2'").Check(testkit.Rows("4"))

// Then m1 cannot update the heartbeat of its task
require.Error(t, m1.UpdateHeartBeat(context.Background(), se, now.Add(time.Hour)))
tk.MustQuery("select owner_hb_time from mysql.tidb_ttl_task").Check(testkit.Rows(
now.Format(time.DateTime),
now.Format(time.DateTime),
now.Format(time.DateTime),
now.Format(time.DateTime),
))

// m2 can successfully update the heartbeat
require.NoError(t, m2.UpdateHeartBeat(context.Background(), se, now.Add(time.Hour)))
tk.MustQuery("select owner_hb_time from mysql.tidb_ttl_task").Check(testkit.Rows(
now.Add(time.Hour).Format(time.DateTime),
now.Add(time.Hour).Format(time.DateTime),
now.Add(time.Hour).Format(time.DateTime),
now.Add(time.Hour).Format(time.DateTime),
))

// Although m1 cannot finish the task. It'll also try to cancel the task.
for _, task := range m1.GetRunningTasks() {
task.SetResult(nil)
}
m1.CheckFinishedTask(se, now)
tk.MustQuery("select count(1) from mysql.tidb_ttl_task where status = 'running'").Check(testkit.Rows("4"))
require.Equal(t, uint32(4), cancelCount.Load())

// Then the tasks in m1 should be cancelled again in `CheckInvalidTask`.
m1.CheckInvalidTask(se)
require.Equal(t, uint32(8), cancelCount.Load())

// m2 can finish the task
for _, task := range m2.GetRunningTasks() {
task.SetResult(nil)
}
m2.CheckFinishedTask(se, now)
tk.MustQuery("select status, state, owner_id from mysql.tidb_ttl_task").Sort().Check(testkit.Rows(
`finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`,
`finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`,
`finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`,
`finished {"total_rows":0,"success_rows":0,"error_rows":0,"scan_task_err":""} task-manager-2`,
))
}
15 changes: 15 additions & 0 deletions pkg/ttl/ttlworker/task_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ func (t *runningScanTask) SetResult(err error) {
t.result = t.ttlScanTask.result(err)
}

// SetCancel sets the cancel function of the task
func (t *runningScanTask) SetCancel(cancel func()) {
t.cancel = cancel
}

// CheckInvalidTask is an exported version of checkInvalidTask
func (m *taskManager) CheckInvalidTask(se session.Session) {
m.checkInvalidTask(se)
}

// UpdateHeartBeat is an exported version of updateHeartBeat
func (m *taskManager) UpdateHeartBeat(ctx context.Context, se session.Session, now time.Time) error {
return m.updateHeartBeat(ctx, se, now)
}

func TestResizeWorkers(t *testing.T) {
tbl := newMockTTLTbl(t, "t1")

Expand Down

0 comments on commit 018ab99

Please sign in to comment.