Skip to content

Commit

Permalink
disttask: make task executor onError print error's stack. (#56618)
Browse files Browse the repository at this point in the history
close #56014
  • Loading branch information
LindaSummer authored Oct 22, 2024
1 parent 670e970 commit a00ba68
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 14 deletions.
4 changes: 3 additions & 1 deletion pkg/disttask/framework/taskexecutor/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ go_test(
],
embed = [":taskexecutor"],
flaky = True,
shard_count = 16,
shard_count = 17,
deps = [
"//pkg/disttask/framework/mock",
"//pkg/disttask/framework/mock/execute",
Expand All @@ -74,5 +74,7 @@ go_test(
"@org_golang_google_grpc//status",
"@org_uber_go_goleak//:goleak",
"@org_uber_go_mock//gomock",
"@org_uber_go_zap//:zap",
"@org_uber_go_zap//zaptest/observer",
],
)
29 changes: 17 additions & 12 deletions pkg/disttask/framework/taskexecutor/task_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package taskexecutor

import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -285,7 +286,7 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error)
taskBase := e.taskBase.Load()
task, err := e.taskTable.GetTaskByID(e.ctx, taskBase.ID)
if err != nil {
e.onError(err)
e.onError(errors.Trace(err))
return e.getError()
}
stepLogger := llog.BeginTask(e.logger.With(
Expand All @@ -301,7 +302,7 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error)

stepExecutor, err := e.GetStepExecutor(task)
if err != nil {
e.onError(err)
e.onError(errors.Trace(err))
return e.getError()
}
execute.SetFrameworkInfo(stepExecutor, resource)
Expand All @@ -310,15 +311,15 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error)
failpoint.Return(errors.New("mockExecSubtaskInitEnvErr"))
})
if err := stepExecutor.Init(runStepCtx); err != nil {
e.onError(err)
e.onError(errors.Trace(err))
return e.getError()
}

defer func() {
err := stepExecutor.Cleanup(runStepCtx)
if err != nil {
e.logger.Error("cleanup subtask exec env failed", zap.Error(err))
e.onError(err)
e.onError(errors.Trace(err))
}
}()

Expand Down Expand Up @@ -362,7 +363,7 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error)
if err == storage.ErrSubtaskNotFound {
continue
}
e.onError(err)
e.onError(errors.Trace(err))
continue
}
}
Expand Down Expand Up @@ -415,7 +416,7 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute.
})

if err != nil {
e.onError(err)
e.onError(errors.Trace(err))
}

finished := e.markSubTaskCanceledOrFailed(ctx, subtask)
Expand Down Expand Up @@ -453,12 +454,12 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute.
func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execute.StepExecutor, subtask *proto.Subtask) {
if err := e.getError(); err == nil {
if err = executor.OnFinished(ctx, subtask); err != nil {
e.onError(err)
e.onError(errors.Trace(err))
}
}
failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) {
if val.(bool) {
e.onError(ErrCancelSubtask)
e.onError(errors.Trace(ErrCancelSubtask))
}
})

Expand Down Expand Up @@ -532,8 +533,12 @@ func (e *BaseTaskExecutor) onError(err error) {
if err == nil {
return
}
err = errors.Trace(err)
e.logger.Error("onError", zap.Error(err), zap.Stack("stack"))

if errors.HasStack(err) {
e.logger.Error("onError", zap.Error(err), zap.Stack("stack"), zap.String("error stack", fmt.Sprintf("%+v", err)))
} else {
e.logger.Error("onError", zap.Error(err), zap.Stack("stack"))
}
e.mu.Lock()
defer e.mu.Unlock()

Expand Down Expand Up @@ -575,7 +580,7 @@ func (e *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, e
},
)
if err != nil {
e.onError(err)
e.onError(errors.Trace(err))
}
}

Expand Down Expand Up @@ -605,7 +610,7 @@ func (e *BaseTaskExecutor) finishSubtask(ctx context.Context, subtask *proto.Sub
},
)
if err != nil {
e.onError(err)
e.onError(errors.Trace(err))
}
}

Expand Down
51 changes: 50 additions & 1 deletion pkg/disttask/framework/taskexecutor/task_executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import (

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/disttask/framework/mock"
"github.com/pingcap/tidb/pkg/disttask/framework/mock/execute"
mockexecute "github.com/pingcap/tidb/pkg/disttask/framework/mock/execute"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/disttask/framework/storage"
"github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap"
"go.uber.org/zap/zaptest/observer"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -546,3 +548,50 @@ func TestInject(t *testing.T) {
got := e.GetResource()
require.Equal(t, r, got)
}

// throwError returns a fresh error with the message "mock error" created via
// errors.New. TestExecutorOnErrorLog expects this function's fully qualified
// name to appear in the logged "error stack" field, so the error is created
// inside this dedicated function.
func throwError() error {
return errors.New("mock error")
}

// callOnError passes the error produced by throwError to taskExecutor.onError.
func callOnError(taskExecutor *BaseTaskExecutor) {
taskExecutor.onError(throwError())
}

// throwErrorNoTrace returns an error with the message "mock error" created via
// errors.NewNoStackError. TestExecutorOnErrorLog expects that logging this
// error does not produce an "error stack" field.
// NOTE(review): presumably NewNoStackError builds an error without an attached
// stack trace — confirm against the pingcap/errors package.
func throwErrorNoTrace() error {
return errors.NewNoStackError("mock error")
}

// callOnErrorNoTrace passes the error produced by throwErrorNoTrace to
// taskExecutor.onError.
func callOnErrorNoTrace(taskExecutor *BaseTaskExecutor) {
taskExecutor.onError(throwErrorNoTrace())
}

func TestExecutorOnErrorLog(t *testing.T) {
taskExecutor := &BaseTaskExecutor{}

observedZapCore, observedLogs := observer.New(zap.ErrorLevel)
observedLogger := zap.New(observedZapCore)
taskExecutor.logger = observedLogger

{
callOnError(taskExecutor)
require.GreaterOrEqual(t, observedLogs.Len(), 1)
errLog := observedLogs.TakeAll()[0]
contextMap := errLog.ContextMap()
require.Contains(t, contextMap, "error stack")
errStack := contextMap["error stack"]
require.IsType(t, "", errStack)
errStackStr := errStack.(string)
require.Regexpf(t, `mock error[\n\t ]*`+
`github\.com/pingcap/tidb/pkg/disttask/framework/taskexecutor\.throwError`,
errStackStr,
"got err stack: %s", errStackStr)
}

{
callOnErrorNoTrace(taskExecutor)
require.GreaterOrEqual(t, observedLogs.Len(), 1)
errLog := observedLogs.TakeAll()[0]
contextMap := errLog.ContextMap()
require.NotContains(t, contextMap, "error stack")
}
}

0 comments on commit a00ba68

Please sign in to comment.