diff --git a/client/allocrunner/taskrunner/plugin_supervisor_hook.go b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
index 21491cf77c..6c246c74cb 100644
--- a/client/allocrunner/taskrunner/plugin_supervisor_hook.go
+++ b/client/allocrunner/taskrunner/plugin_supervisor_hook.go
@@ -38,6 +38,7 @@ type csiPluginSupervisorHook struct {
 
 	// eventEmitter is used to emit events to the task
 	eventEmitter ti.EventEmitter
+	lifecycle    ti.TaskLifecycle
 
 	shutdownCtx      context.Context
 	shutdownCancelFn context.CancelFunc
@@ -54,6 +55,7 @@ type csiPluginSupervisorHookConfig struct {
 	clientStateDirPath string
 	events             ti.EventEmitter
 	runner             *TaskRunner
+	lifecycle          ti.TaskLifecycle
 	capabilities       *drivers.Capabilities
 	logger             hclog.Logger
 }
@@ -90,6 +92,7 @@ func newCSIPluginSupervisorHook(config *csiPluginSupervisorHookConfig) *csiPlugi
 	hook := &csiPluginSupervisorHook{
 		alloc:      config.runner.Alloc(),
 		runner:     config.runner,
+		lifecycle:  config.lifecycle,
 		logger:     config.logger,
 		task:       task,
 		mountPoint: pluginRoot,
@@ -209,20 +212,27 @@ func (h *csiPluginSupervisorHook) ensureSupervisorLoop(ctx context.Context) {
 
 	t := time.NewTimer(0)
 
+	// We're in Poststart at this point, so if we can't connect within
+	// this deadline, assume it's broken so we can restart the task
+	startCtx, startCancelFn := context.WithTimeout(ctx, 30*time.Second)
+	defer startCancelFn()
+
+	var err error
+	var pluginHealthy bool
+
 	// Step 1: Wait for the plugin to initially become available.
 WAITFORREADY:
 	for {
 		select {
-		case <-ctx.Done():
+		case <-startCtx.Done():
+			h.kill(ctx, fmt.Errorf("CSI plugin failed probe: %v", err))
 			return
 		case <-t.C:
-			pluginHealthy, err := h.supervisorLoopOnce(ctx, client)
+			pluginHealthy, err = h.supervisorLoopOnce(startCtx, client)
 			if err != nil || !pluginHealthy {
-				h.logger.Debug("CSI Plugin not ready", "error", err)
-
-				// Plugin is not yet returning healthy, because we want to optimise for
-				// quickly bringing a plugin online, we use a short timeout here.
-				// TODO(dani): Test with more plugins and adjust.
+				h.logger.Debug("CSI plugin not ready", "error", err)
+				// Use only a short delay here to optimize for quickly
+				// bringing up a plugin
 				t.Reset(5 * time.Second)
 				continue
 			}
@@ -240,13 +250,11 @@ WAITFORREADY:
 	// Step 2: Register the plugin with the catalog.
 	deregisterPluginFn, err := h.registerPlugin(client, socketPath)
 	if err != nil {
-		h.logger.Error("CSI plugin registration failed", "error", err)
-		event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
-		event.SetMessage(fmt.Sprintf("failed to register plugin: %s, reason: %v", h.task.CSIPluginConfig.ID, err))
-		h.eventEmitter.EmitEvent(event)
+		h.kill(ctx, fmt.Errorf("CSI plugin failed to register: %v", err))
 	}
 
-	// Step 3: Start the lightweight supervisor loop.
+	// Step 3: Start the lightweight supervisor loop. At this point, failures
+	// don't cause the task to restart
 	t.Reset(0)
 	for {
 		select {
@@ -271,7 +279,7 @@ WAITFORREADY:
 			if h.previousHealthState && !pluginHealthy {
 				event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
 				if err != nil {
-					event.SetMessage(fmt.Sprintf("error: %v", err))
+					event.SetMessage(fmt.Sprintf("Error: %v", err))
 				} else {
 					event.SetMessage("Unknown Reason")
 				}
@@ -359,7 +367,7 @@ func (h *csiPluginSupervisorHook) supervisorLoopOnce(ctx context.Context, client
 
 	healthy, err := client.PluginProbe(probeCtx)
 	if err != nil {
-		return false, fmt.Errorf("failed to probe plugin: %v", err)
+		return false, err
 	}
 
 	return healthy, nil
@@ -378,6 +386,21 @@ func (h *csiPluginSupervisorHook) Stop(_ context.Context, req *interfaces.TaskSt
 	return nil
 }
 
+func (h *csiPluginSupervisorHook) kill(ctx context.Context, reason error) {
+	h.logger.Error("killing task because plugin failed", "error", reason)
+	event := structs.NewTaskEvent(structs.TaskPluginUnhealthy)
+	event.SetMessage(fmt.Sprintf("Error: %v", reason.Error()))
+	h.eventEmitter.EmitEvent(event)
+
+	if err := h.lifecycle.Kill(ctx,
+		structs.NewTaskEvent(structs.TaskKilling).
+			SetFailsTask().
+			SetDisplayMessage("CSI plugin did not become healthy before timeout"),
+	); err != nil {
+		h.logger.Error("failed to kill task", "kill_reason", reason, "error", err)
+	}
+}
+
 func ensureMountpointInserted(mounts []*drivers.MountConfig, mount *drivers.MountConfig) []*drivers.MountConfig {
 	for _, mnt := range mounts {
 		if mnt.IsEqual(mount) {
diff --git a/client/allocrunner/taskrunner/task_runner_hooks.go b/client/allocrunner/taskrunner/task_runner_hooks.go
index b3f44a8b98..62ff26c4b6 100644
--- a/client/allocrunner/taskrunner/task_runner_hooks.go
+++ b/client/allocrunner/taskrunner/task_runner_hooks.go
@@ -76,6 +76,7 @@ func (tr *TaskRunner) initHooks() {
 			clientStateDirPath: tr.clientConfig.StateDir,
 			events:             tr,
 			runner:             tr,
+			lifecycle:          tr,
 			capabilities:       tr.driverCapabilities,
 			logger:             hookLogger,
 		}))