19 changes: 13 additions & 6 deletions cmd/nvidia-ctk-installer/container/container.go
@@ -49,12 +49,15 @@ type Options struct {
// mount.
ExecutablePath string
// EnabledCDI indicates whether CDI should be enabled.
EnableCDI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
EnableCDI bool
EnableNRI bool
RuntimeName string
RuntimeDir string
SetAsDefault bool
RestartMode string
HostRootMount string
NRIPluginIndex string
NRISocket string

ConfigSources []string
}
@@ -128,6 +131,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
cfg.EnableCDI()
}

if o.EnableNRI {
cfg.EnableNRI()
}

return nil
}
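
For orientation, a minimal sketch of how a caller would exercise the new field; the concrete field values and the cfg variable are placeholders for this example and are not taken from the PR:

	opts := container.Options{
		RuntimeName: "nvidia",
		EnableCDI:   true,
		EnableNRI:   true,
	}
	// cfg is any engine.Interface implementation (containerd, cri-o, or docker).
	if err := opts.UpdateConfig(cfg); err != nil {
		return fmt.Errorf("unable to update runtime config: %w", err)
	}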

146 changes: 146 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/nri/plugin.go
@@ -0,0 +1,146 @@
package nri

import (
"context"
"fmt"
"os"

"github.com/containerd/nri/pkg/api"
nriplugin "github.com/containerd/nri/pkg/stub"
"sigs.k8s.io/yaml"

Member: Does it make sense to use a comma-separated list of devices instead of having to parse the YAML?

Contributor Author: Yes, I am open to that.

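For illustration only, parsing a comma-separated value instead of YAML could look roughly like the sketch below; the helper name and the value format are assumptions and not part of this change (it relies on the standard library strings package):

	// parseCDIDeviceList splits a value such as "nvidia.com/gpu=0,nvidia.com/gpu=1".
	func parseCDIDeviceList(annotation string) []string {
		var devices []string
		for _, d := range strings.Split(annotation, ",") {
			if d = strings.TrimSpace(d); d != "" {
				devices = append(devices, d)
			}
		}
		return devices
	}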

"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
)

// Compile-time interface checks
var (
_ nriplugin.Plugin = (*Plugin)(nil)
)

const (
// nodeResourceCDIDeviceKey is the prefix of the key used for CDI device annotations.
nodeResourceCDIDeviceKey = "cdi-devices.noderesource.dev"
// nriCDIDeviceKey is the prefix of the key used for CDI device annotations.
nriCDIDeviceKey = "cdi-devices.nri.io"

Member: Could you comment on why we need two prefixes here? Are these not determined by us? (https://github.com/NVIDIA/gpu-operator/pull/1950/files#diff-e6f52ba1392796db4c79e078d3f1067c50e3bfde9d90f3aaaad3eb3e3f4d84fbR20-R21) Why not only respond to one of them?

Contributor Author: We don't. Yes, these should be defined by us and one annotation should be sufficient. This is taken from the example plugin code of containerd/nri.

// defaultNRISocket represents the default path of the NRI socket
defaultNRISocket = "/var/run/nri/nri.sock"
)

type Plugin struct {
logger logger.Interface

stub nriplugin.Stub
}

// NewPlugin creates a new NRI plugin for injecting CDI devices
func NewPlugin(logger logger.Interface) *Plugin {
return &Plugin{
logger: logger,
}
}

// CreateContainer handles container creation requests.
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
adjust := &api.ContainerAdjustment{}

if err := p.injectCDIDevices(pod, ctr, adjust); err != nil {
return nil, nil, err
}

return adjust, nil, nil
}

func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
devices, err := parseCDIDevices(ctr.Name, pod.Annotations)
if err != nil {
return err
}

if len(devices) == 0 {
p.logger.Debugf("%s: no CDI devices annotated...", containerName(pod, ctr))
return nil
}

for _, name := range devices {
a.AddCDIDevice(
&api.CDIDevice{

Member: As far as I am aware, this introduces restrictions on compatible containerd / cri-o versions.

Contributor Author: Given that we've moved to native CDI, the minimum supported containerd is now v1.7. Will that be affected by this change?

Name: name,
},
)
p.logger.Infof("%s: injected CDI device %q...", containerName(pod, ctr), name)
}

return nil
}

func parseCDIDevices(ctr string, annotations map[string]string) ([]string, error) {
var (
cdiDevices []string
)

annotation := getAnnotation(annotations, nodeResourceCDIDeviceKey, nriCDIDeviceKey, ctr)
if len(annotation) == 0 {
return nil, nil
}

if err := yaml.Unmarshal(annotation, &cdiDevices); err != nil {
return nil, fmt.Errorf("invalid CDI device annotation %q: %w", string(annotation), err)
}

return cdiDevices, nil
}

func getAnnotation(annotations map[string]string, mainKey, oldKey, ctr string) []byte {
for _, key := range []string{
mainKey + "/container." + ctr,
oldKey + "/container." + ctr,
mainKey + "/pod",
oldKey + "/pod",
mainKey,
oldKey,
} {
Comment on lines +94 to +101

Member: Instead of having to deal with two keys, could we rather have a single function that we call for each of the keys?

Contributor Author: The plan is to definitely deal with just one annotation key for now. This was taken from the example plugin code.

if value, ok := annotations[key]; ok {
return []byte(value)
}
}

return nil
}

// Construct a container name for log messages.
func containerName(pod *api.PodSandbox, container *api.Container) string {
if pod != nil {
return pod.Name + "/" + container.Name
}
return container.Name
}

// Start starts the NRI plugin
func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error {
if len(nriSocketPath) == 0 {
nriSocketPath = defaultNRISocket
}
_, err := os.Stat(nriSocketPath)
if err != nil {
return fmt.Errorf("failed to find valid nri socket in %s: %w", nriSocketPath, err)
}

var pluginOpts []nriplugin.Option
pluginOpts = append(pluginOpts, nriplugin.WithPluginIdx(nriPluginIdx))
pluginOpts = append(pluginOpts, nriplugin.WithSocketPath(nriSocketPath))
Comment on lines +128 to +130

Member (@elezar, Dec 5, 2025) suggested change:

-	var pluginOpts []nriplugin.Option
-	pluginOpts = append(pluginOpts, nriplugin.WithPluginIdx(nriPluginIdx))
-	pluginOpts = append(pluginOpts, nriplugin.WithSocketPath(nriSocketPath))
+	pluginOpts := []nriplugin.Option{
+		nriplugin.WithPluginIdx(nriPluginIdx),
+		nriplugin.WithSocketPath(nriSocketPath),
+	}

Contributor Author: Thanks.

if p.stub, err = nriplugin.New(p, pluginOpts...); err != nil {
return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
}
err = p.stub.Start(ctx)
if err != nil {
return fmt.Errorf("plugin exited with error: %w", err)
}
return nil
}

// Stop stops the NRI plugin
func (p *Plugin) Stop() {
if p != nil && p.stub != nil {
p.stub.Stop()
}
}
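
To make the annotation contract concrete, here is a small usage sketch of parseCDIDevices; the container name and device names are invented for the example, and the value is a JSON-style list, which yaml.Unmarshal accepts:

	annotations := map[string]string{
		// per-container form: <prefix>/container.<container-name>
		"cdi-devices.nri.io/container.cuda-test": `["nvidia.com/gpu=0", "nvidia.com/gpu=1"]`,
	}
	devices, err := parseCDIDevices("cuda-test", annotations)
	// devices == []string{"nvidia.com/gpu=0", "nvidia.com/gpu=1"}, err == nil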
23 changes: 23 additions & 0 deletions cmd/nvidia-ctk-installer/container/runtime/runtime.go
@@ -34,6 +34,8 @@ const (
// defaultRuntimeName specifies the NVIDIA runtime to be use as the default runtime if setting the default runtime is enabled
defaultRuntimeName = "nvidia"
defaultHostRootMount = "/host"
defaultNRIPluginIdx = "10"
defaultNRISocket = "/var/run/nri/nri.sock"

runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT"
)
@@ -94,6 +96,27 @@ func Flags(opts *Options) []cli.Flag {
Destination: &opts.EnableCDI,
Sources: cli.EnvVars("RUNTIME_ENABLE_CDI"),
},
&cli.BoolFlag{
Name: "enable-nri-in-runtime",
Usage: "Enable NRI in the configured runtime",
Destination: &opts.EnableNRI,
Value: true,
Sources: cli.EnvVars("RUNTIME_ENABLE_NRI"),
},
&cli.StringFlag{
Name: "nri-plugin-index",
Usage: "Specify the plugin index to register to NRI",
Value: defaultNRIPluginIdx,
Destination: &opts.NRIPluginIndex,
Sources: cli.EnvVars("RUNTIME_NRI_PLUGIN_INDEX"),
},
&cli.StringFlag{
Name: "nri-socket",
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
Value: defaultNRISocket,
Destination: &opts.NRISocket,
Sources: cli.EnvVars("RUNTIME_NRI_SOCKET"),
},
&cli.StringFlag{
Name: "host-root",
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",
51 changes: 46 additions & 5 deletions cmd/nvidia-ctk-installer/main.go
@@ -7,11 +7,13 @@ import (
"os/signal"
"path/filepath"
"syscall"
"time"

"github.com/urfave/cli/v3"
"golang.org/x/sys/unix"

"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@@ -26,6 +28,9 @@ const (
toolkitSubDir = "toolkit"

defaultRuntime = "docker"

retryBackoff = 2 * time.Second
maxRetryAttempts = 5
)

var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
@@ -70,13 +75,15 @@ func main() {
type app struct {
logger logger.Interface

toolkit *toolkit.Installer
nriPlugin *nri.Plugin

Member: Do we really need to embed this here? Does it make sense to just instantiate it when required and use a deferred shutdown?

Member: Something like:

diff --git a/cmd/nvidia-ctk-installer/main.go b/cmd/nvidia-ctk-installer/main.go
index 39ee1abe..c725ee3e 100644
--- a/cmd/nvidia-ctk-installer/main.go
+++ b/cmd/nvidia-ctk-installer/main.go
@@ -75,15 +75,13 @@ func main() {
 type app struct {
 	logger logger.Interface
 
-	nriPlugin *nri.Plugin
-	toolkit   *toolkit.Installer
+	toolkit *toolkit.Installer
 }
 
 // NewApp creates the CLI app fro the specified options.
 func NewApp(logger logger.Interface) *cli.Command {
 	a := app{
-		logger:    logger,
-		nriPlugin: nri.NewPlugin(logger),
+		logger: logger,
 	}
 	return a.build()
 }
@@ -229,11 +227,12 @@ func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
 	}
 
 	if !o.noDaemon {
-		if o.runtimeOptions.EnableNRI {
-			if err = a.startNRIPluginServer(ctx, o.runtimeOptions); err != nil {
-				a.logger.Errorf("unable to start NRI plugin server: %v", err)
-			}
+		nriPlugin, err := a.startNRIPluginServer(ctx, o.runtimeOptions)
+		if err != nil {
+			a.logger.Errorf("unable to start NRI plugin server: %v", err)
 		}
+		defer nriPlugin.Stop()
+
 		err = a.waitForSignal()
 		if err != nil {
 			return fmt.Errorf("unable to wait for signal: %v", err)
@@ -299,11 +298,15 @@ func (a *app) waitForSignal() error {
 	return nil
 }
 
-func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) error {
+func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) (*nri.Plugin, error) {
+	if !opts.EnableNRI {
+		return nil, nil
+	}
 	a.logger.Infof("Starting the NRI Plugin server....")
 
+	plugin := nri.NewPlugin(a.logger)
 	retriable := func() error {
-		return a.nriPlugin.Start(ctx, opts.NRISocket, opts.NRIPluginIndex)
+		return plugin.Start(ctx, opts.NRISocket, opts.NRIPluginIndex)
 	}
 	var err error
 	for i := 0; i < maxRetryAttempts; i++ {
@@ -318,19 +321,13 @@ func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) er
 	}
 	if err != nil {
 		a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
-		return err
+		return nil, err
 	}
-	return nil
+	return plugin, nil
 }
 
 func (a *app) shutdown(pidFile string) {
 	a.logger.Infof("Shutting Down")
-
-	if a.nriPlugin != nil {
-		a.logger.Infof("Stopping NRI plugin server...")
-		a.nriPlugin.Stop()
-	}
-
 	err := os.Remove(pidFile)
 	if err != nil {
 		a.logger.Warningf("Unable to remove pidfile: %v", err)


Contributor Author: Agreed, this is much better. Thanks for the suggestion!

toolkit *toolkit.Installer
}

// NewApp creates the CLI app fro the specified options.
func NewApp(logger logger.Interface) *cli.Command {
a := app{
logger: logger,
logger: logger,
nriPlugin: nri.NewPlugin(logger),
}
return a.build()
}
@@ -93,8 +100,8 @@ func (a app) build() *cli.Command {
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
return ctx, a.Before(cmd, &options)
},
Action: func(_ context.Context, cmd *cli.Command) error {
return a.Run(cmd, &options)
Action: func(ctx context.Context, cmd *cli.Command) error {
return a.Run(ctx, cmd, &options)
},
Flags: []cli.Flag{
&cli.BoolFlag{
@@ -194,7 +201,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
// If the application is run as a daemon, the application waits and unconfigures
// the runtime on termination.
func (a *app) Run(c *cli.Command, o *options) error {
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
err := a.initialize(o.pidFile)
if err != nil {
return fmt.Errorf("unable to initialize: %v", err)
@@ -222,6 +229,11 @@ func (a *app) Run(c *cli.Command, o *options) error {
}

if !o.noDaemon {
if o.runtimeOptions.EnableNRI {
if err = a.startNRIPluginServer(ctx, o.runtimeOptions); err != nil {
a.logger.Errorf("unable to start NRI plugin server: %v", err)
}
}
err = a.waitForSignal()
if err != nil {
return fmt.Errorf("unable to wait for signal: %v", err)
@@ -287,9 +299,38 @@ func (a *app) waitForSignal() error {
return nil
}

func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) error {
a.logger.Infof("Starting the NRI Plugin server....")

retriable := func() error {
return a.nriPlugin.Start(ctx, opts.NRISocket, opts.NRIPluginIndex)
}
var err error
for i := 0; i < maxRetryAttempts; i++ {
err = retriable()
if err == nil {
break
}
if i == maxRetryAttempts-1 {
break
}
time.Sleep(retryBackoff)
}
if err != nil {
a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
return err
}
return nil
}
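
A note on the retry loop above: with maxRetryAttempts = 5 and retryBackoff = 2 * time.Second, the plugin start is attempted at most five times with two-second pauses between attempts, so a persistent failure is reported after roughly eight seconds of backoff plus the time spent in the Start calls themselves. An equivalent, slightly more compact formulation (a sketch only, not part of the PR) would be:

	var err error
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		if err = retriable(); err == nil {
			return nil
		}
		if attempt < maxRetryAttempts {
			time.Sleep(retryBackoff)
		}
	}
	a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
	return err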

func (a *app) shutdown(pidFile string) {
a.logger.Infof("Shutting Down")

if a.nriPlugin != nil {
a.logger.Infof("Stopping NRI plugin server...")
a.nriPlugin.Stop()
}

err := os.Remove(pidFile)
if err != nil {
a.logger.Warningf("Unable to remove pidfile: %v", err)
1 change: 1 addition & 0 deletions cmd/nvidia-ctk-installer/main_test.go
@@ -444,6 +444,7 @@ version = 2
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
"--restart-mode=none",
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"),
"--enable-nri-in-runtime=false",
}

err := app.Run(context.Background(), append(testArgs, tc.args...))
14 changes: 11 additions & 3 deletions go.mod
@@ -5,6 +5,7 @@ go 1.25.0
require (
github.com/NVIDIA/go-nvlib v0.9.0
github.com/NVIDIA/go-nvml v0.13.0-1
github.com/containerd/nri v0.10.1-0.20251120153915-7d8611f87ad7
github.com/google/uuid v1.6.0
github.com/moby/sys/mountinfo v0.7.2
github.com/moby/sys/reexec v0.1.0
@@ -19,24 +20,31 @@
github.com/urfave/cli/v3 v3.6.1
golang.org/x/mod v0.30.0
golang.org/x/sys v0.38.0
sigs.k8s.io/yaml v1.4.0
tags.cncf.io/container-device-interface v1.0.2-0.20251114135136-1b24d969689f
tags.cncf.io/container-device-interface/specs-go v1.0.0
)

require (
cyphar.com/go-pathrs v0.2.1 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/ttrpc v1.2.7 // indirect
github.com/cyphar/filepath-securejoin v0.6.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/knqyf263/go-plugin v0.9.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/moby/sys/capability v0.4.0 // indirect
github.com/opencontainers/cgroups v0.0.4 // indirect
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.11.0 // indirect
github.com/tetratelabs/wazero v1.9.0 // indirect
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
google.golang.org/grpc v1.57.1 // indirect
google.golang.org/protobuf v1.36.5 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
sigs.k8s.io/yaml v1.4.0 // indirect
)