diff --git a/daemon/mgr/spec_hook.go b/daemon/mgr/spec_hook.go index e55ab202a..d535d9ada 100644 --- a/daemon/mgr/spec_hook.go +++ b/daemon/mgr/spec_hook.go @@ -2,7 +2,6 @@ package mgr import ( "context" - "os/exec" "sort" "strings" @@ -49,7 +48,7 @@ func setupHook(ctx context.Context, c *Container, specWrapper *SpecWrapper) erro } // set nvidia config - if err := setNvidiaHook(ctx, c, specWrapper); err != nil { + if err := setNvidiaHook(c, specWrapper); err != nil { return errors.Wrap(err, "failed to set nvidia prestart hook") } @@ -94,21 +93,3 @@ func (w *wrapperEmbedPrestart) Priority() int { func (w *wrapperEmbedPrestart) Hook() []string { return w.args } - -func setNvidiaHook(ctx context.Context, c *Container, spec *SpecWrapper) error { - n := c.HostConfig.NvidiaConfig - if n == nil { - return nil - } - path, err := exec.LookPath("nvidia-container-runtime-hook") - if err != nil { - return err - } - args := []string{path} - nvidiaPrestart := specs.Hook{ - Path: path, - Args: append(args, "prestart"), - } - spec.s.Hooks.Prestart = append(spec.s.Hooks.Prestart, nvidiaPrestart) - return nil -} diff --git a/daemon/mgr/spec_nvidia_hook.go b/daemon/mgr/spec_nvidia_hook.go new file mode 100644 index 000000000..e307f901d --- /dev/null +++ b/daemon/mgr/spec_nvidia_hook.go @@ -0,0 +1,40 @@ +package mgr + +import ( + "os/exec" + + "github.com/alibaba/pouch/pkg/utils" + + "github.com/opencontainers/runtime-spec/specs-go" +) + +var ( + nvidiaHookName = "nvidia-container-runtime-hook" +) + +func setNvidiaHook(c *Container, spec *SpecWrapper) error { + n := c.HostConfig.NvidiaConfig + + // to make compatible for k8s. + // if user set environments of NVIDIA, then set prestart hook + kv := utils.ConvertKVStrToMapWithNoErr(c.Config.Env) + _, hasEnvCapabilities := kv["NVIDIA_DRIVER_CAPABILITIES"] + _, hasEnvDevices := kv["NVIDIA_VISIBLE_DEVICES"] + + if n == nil && !hasEnvCapabilities && !hasEnvDevices { + return nil + } + + path, err := exec.LookPath(nvidiaHookName) + if err != nil { + return err + } + args := []string{path} + nvidiaPrestart := specs.Hook{ + Path: path, + Args: append(args, "prestart"), + } + spec.s.Hooks.Prestart = append(spec.s.Hooks.Prestart, nvidiaPrestart) + + return nil +} diff --git a/daemon/mgr/spec_nvidia_hook_test.go b/daemon/mgr/spec_nvidia_hook_test.go new file mode 100644 index 000000000..0065605ba --- /dev/null +++ b/daemon/mgr/spec_nvidia_hook_test.go @@ -0,0 +1,88 @@ +package mgr + +import ( + "os" + "os/exec" + "path" + "reflect" + "testing" + + "github.com/alibaba/pouch/apis/types" + "github.com/opencontainers/runtime-spec/specs-go" +) + +func Test_setNvidiaHook(t *testing.T) { + nvidiaHookName = "test-nvidia-container-runtime-hook" + installDir := "/usr/local/bin/" + fullname := path.Join(installDir, nvidiaHookName) + os.Remove(fullname) + os.Create(fullname) + os.Chmod(fullname, 0755) + path, _ := exec.LookPath(nvidiaHookName) + defer func() { + os.Remove(fullname) + }() + + tests := []struct { + name string + c *Container + specWrapper *SpecWrapper + expectedPrestart []specs.Hook + }{ + { + "NvidiaConfig is nil, NvidiaEnv is null", + &Container{ + HostConfig: &types.HostConfig{ + Resources: types.Resources{ + NvidiaConfig: nil, + }, + }, + Config: &types.ContainerConfig{ + Env: []string{}, + }, + }, + &SpecWrapper{ + s: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{}, + }, + }, + }, + []specs.Hook{}, + }, + { + "NvidiaConfig is nil, NvidiaEnv not null", + &Container{ + HostConfig: &types.HostConfig{ + Resources: types.Resources{ + NvidiaConfig: nil, + }, + }, + Config: &types.ContainerConfig{ + Env: []string{"NVIDIA_DRIVER_CAPABILITIES=all", "NVIDIA_VISIBLE_DEVICES=all"}, + }, + }, + &SpecWrapper{ + s: &specs.Spec{ + Hooks: &specs.Hooks{ + Prestart: []specs.Hook{}, + }, + }, + }, + // exec.LookPath("nvidia-container-runtime-hook") return error, + []specs.Hook{specs.Hook{ + Path: path, + Args: append([]string{path}, "prestart"), + }}, + }, + } + for _, tt := range tests { + err := setNvidiaHook(tt.c, tt.specWrapper) + if err != nil { + t.Errorf("setNvidiaHook = %v, want %v", err, nil) + } + if !reflect.DeepEqual(tt.specWrapper.s.Hooks.Prestart, tt.expectedPrestart) { + t.Errorf("setNvidiaHook = %v, want %v", tt.specWrapper.s.Hooks.Poststart, tt.expectedPrestart) + } + } +}