Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

drivers: update ordering of events in StartTask to fix executor leak #24495

Merged
merged 2 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/24495.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
drivers: fix executor leak when drivers error starting tasks
```
26 changes: 15 additions & 11 deletions drivers/exec/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return nil
}

func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (handle *drivers.TaskHandle, network *drivers.DriverNetwork, err error) {
if _, ok := d.tasks.Get(cfg.ID); ok {
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
}
Expand All @@ -481,7 +481,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
}

d.logger.Info("starting task", "driver_cfg", hclog.Fmt("%+v", driverConfig))
handle := drivers.NewTaskHandle(taskHandleVersion)
handle = drivers.NewTaskHandle(taskHandleVersion)
handle.Config = cfg

pluginLogFile := filepath.Join(cfg.TaskDir().Dir, "executor.out")
Expand All @@ -492,13 +492,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
Compute: d.compute,
}

exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}

user := cfg.User
if cfg.DNS != nil {
dnsMount, err := resolvconf.GenerateDNSMount(cfg.TaskDir().Dir, cfg.DNS)
Expand All @@ -516,6 +509,19 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
}
d.logger.Debug("task capabilities", "capabilities", caps)

exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}
// prevent leaking executor in error scenarios
defer func() {
if err != nil {
pluginClient.Kill()
}
}()

execCmd := &executor.ExecCommand{
Cmd: driverConfig.Command,
Args: driverConfig.Args,
Expand All @@ -538,7 +544,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive

ps, err := exec.Launch(execCmd)
if err != nil {
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to launch command with executor: %v", err)
}

Expand All @@ -562,7 +567,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
if err := handle.SetDriverState(&driverState); err != nil {
d.logger.Error("failed to start task, error setting driver state", "error", err)
_ = exec.Shutdown("", 0)
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to set driver state: %v", err)
}

Expand Down
64 changes: 63 additions & 1 deletion drivers/exec/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"github.com/hashicorp/nomad/testutil"
"github.com/shoenig/test/must"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)

type mockIDValidator struct{}
Expand Down Expand Up @@ -347,9 +348,70 @@ func TestExecDriver_StartWaitRecover(t *testing.T) {
require.NoError(t, harness.DestroyTask(task.ID, true))
}

func TestExecDriver_NoOrphanedExecutor(t *testing.T) {
ci.Parallel(t)
ctestutils.ExecCompatible(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

d := newExecDriverTest(t, ctx)
harness := dtestutil.NewDriverHarness(t, d)
defer harness.Kill()

config := &Config{
NoPivotRoot: false,
DefaultModePID: executor.IsolationModePrivate,
DefaultModeIPC: executor.IsolationModePrivate,
}

var data []byte
must.NoError(t, base.MsgPackEncode(&data, config))
baseConfig := &base.Config{
PluginConfig: data,
AgentConfig: &base.AgentConfig{
Driver: &base.ClientDriverConfig{
Topology: d.(*Driver).nomadConfig.Topology,
},
},
}
must.NoError(t, harness.SetConfig(baseConfig))

allocID := uuid.Generate()
taskName := "test"
task := &drivers.TaskConfig{
AllocID: allocID,
ID: uuid.Generate(),
Name: taskName,
Resources: testResources(allocID, taskName),
}

cleanup := harness.MkAllocDir(task, true)
defer cleanup()

taskConfig := map[string]interface{}{}
taskConfig["command"] = "force-an-error"
must.NoError(t, task.EncodeConcreteDriverConfig(&taskConfig))

_, _, err := harness.StartTask(task)
must.Error(t, err)
defer harness.DestroyTask(task.ID, true)

testPid := unix.Getpid()
tids, err := os.ReadDir(fmt.Sprintf("/proc/%d/task", testPid))
must.NoError(t, err)
for _, tid := range tids {
children, err := os.ReadFile(fmt.Sprintf("/proc/%d/task/%s/children", testPid, tid.Name()))
must.NoError(t, err)

pids := strings.Fields(string(children))
must.Eq(t, 0, len(pids))
}
}

// TestExecDriver_NoOrphans asserts that when the main
// task dies, the orphans in the PID namespaces are killed by the kernel
func TestExecDriver_NoOrphans(t *testing.T) {
func TestExecDriver_NoOrphanedTasks(t *testing.T) {
ci.Parallel(t)
ctestutils.ExecCompatible(t)

Expand Down
26 changes: 15 additions & 11 deletions drivers/java/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func (d *Driver) RecoverTask(handle *drivers.TaskHandle) error {
return nil
}

func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drivers.DriverNetwork, error) {
func (d *Driver) StartTask(cfg *drivers.TaskConfig) (handle *drivers.TaskHandle, network *drivers.DriverNetwork, err error) {
if _, ok := d.tasks.Get(cfg.ID); ok {
return nil, nil, fmt.Errorf("task with ID %q already started", cfg.ID)
}
Expand All @@ -456,7 +456,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive

d.logger.Info("starting java task", "driver_cfg", hclog.Fmt("%+v", driverConfig), "args", args)

handle := drivers.NewTaskHandle(taskHandleVersion)
handle = drivers.NewTaskHandle(taskHandleVersion)
handle.Config = cfg

pluginLogFile := filepath.Join(cfg.TaskDir().Dir, "executor.out")
Expand All @@ -467,13 +467,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
Compute: d.nomadConfig.Topology.Compute(),
}

exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}

user := cfg.User
if user == "" {
user = "nobody"
Expand All @@ -495,6 +488,19 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
}
d.logger.Debug("task capabilities", "capabilities", caps)

exec, pluginClient, err := executor.CreateExecutor(
d.logger.With("task_name", handle.Config.Name, "alloc_id", handle.Config.AllocID),
d.nomadConfig, executorConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to create executor: %v", err)
}
// prevent leaking executor in error scenarios
defer func() {
if err != nil {
pluginClient.Kill()
}
}()

execCmd := &executor.ExecCommand{
Cmd: absPath,
Args: args,
Expand All @@ -516,7 +522,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive

ps, err := exec.Launch(execCmd)
if err != nil {
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to launch command with executor: %v", err)
}

Expand All @@ -540,7 +545,6 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *drive
if err := handle.SetDriverState(&driverState); err != nil {
d.logger.Error("failed to start task, error setting driver state", "error", err)
exec.Shutdown("", 0)
pluginClient.Kill()
return nil, nil, fmt.Errorf("failed to set driver state: %v", err)
}

Expand Down
Loading