Skip to content

Commit

Permalink
Merge branch 'main' of github.com:meetrajvala/go-tpm-tools into gpu-s…
Browse files Browse the repository at this point in the history
…upport
  • Loading branch information
meetrajvala committed Oct 10, 2024
2 parents 0ae2e12 + 6405dbb commit d658692
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 47 deletions.
34 changes: 14 additions & 20 deletions launcher/container_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,20 @@ func NewRunner(ctx context.Context, cdClient *containerd.Client, token oauth2.To
}

if launchSpec.Experiments.EnableGpuDriverInstallation && launchSpec.InstallGpuDriver {
mounts = appendGpuDriverMounts(mounts)
specOpts = append(specOpts, oci.WithMounts(mounts))
gpuMounts := []specs.Mount{
{
Type: "volume",
Source: fmt.Sprintf("%s/lib64", gpu.InstallationHostDir),
Destination: fmt.Sprintf("%s/lib64", gpu.InstallationContainerDir),
Options: []string{"rbind", "rw"},
}, {
Type: "volume",
Source: fmt.Sprintf("%s/bin", gpu.InstallationHostDir),
Destination: fmt.Sprintf("%s/bin", gpu.InstallationContainerDir),
Options: []string{"rbind", "rw"},
},
}
specOpts = append(specOpts, oci.WithMounts(gpuMounts))

gpuDeviceFiles, err := listFilesWithPrefix("/dev", "nvidia")
if err != nil {
Expand Down Expand Up @@ -282,24 +294,6 @@ func appendTokenMounts(mounts []specs.Mount) []specs.Mount {
return append(mounts, m)
}

// appendGpuMounts appends the default mount specs for GPU drivers
func appendGpuDriverMounts(mounts []specs.Mount) []specs.Mount {
gpuMounts := []specs.Mount{
{
Type: "volume",
Source: fmt.Sprintf("%s/lib64", gpu.InstallationHostDir),
Destination: fmt.Sprintf("%s/lib64", gpu.InstallationContainerDir),
Options: []string{"rbind", "rw"},
}, {
Type: "volume",
Source: fmt.Sprintf("%s/bin", gpu.InstallationHostDir),
Destination: fmt.Sprintf("%s/bin", gpu.InstallationContainerDir),
Options: []string{"rbind", "rw"},
},
}
return append(mounts, gpuMounts...)
}

func (r *ContainerRunner) measureCELEvents(ctx context.Context) error {
if err := r.measureContainerClaims(ctx); err != nil {
return fmt.Errorf("failed to measure container claims: %v", err)
Expand Down
26 changes: 6 additions & 20 deletions launcher/internal/gpu/driverinstaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ const (
installerSnapshotID = "tee-gpu-driver-installer-snapshot"
)

// SupportedGpuTypes is the list of supported gpu types with open sourced nvidia kernel modules.
var SupportedGpuTypes = []deviceinfo.GPUType{
var supportedGpuTypes = []deviceinfo.GPUType{
deviceinfo.L4,
deviceinfo.T4,
deviceinfo.A100_40GB,
Expand Down Expand Up @@ -53,13 +52,17 @@ func NewDriverInstaller(cdClient *containerd.Client, launchSpec spec.LaunchSpec,
// https://pkg.go.dev/cos.googlesource.com/cos/tools.git@v0.0.0-20241008015903-8431fe581b1f/src/cmd/cos_gpu_installer#section-readme
// README specifies docker command where this function uses containerd for launching and managing the gpu driver installer container.
func (di *DriverInstaller) InstallGPUDrivers(ctx context.Context) error {
if err := os.MkdirAll(InstallationHostDir, 0755); err != nil {
return fmt.Errorf("failed to create dir %q: %v", InstallationHostDir, err)
}

gpuType, err := deviceinfo.GetGPUTypeInfo()
if err != nil {
return fmt.Errorf("failed to get the gpu type info: %v", err)
}

if !gpuType.OpenSupported() {
return fmt.Errorf("unsupported gpu type %s, please retry with one of the supported gpu types: %v", gpuType.String(), gpu.SupportedGpuTypes)
return fmt.Errorf("unsupported gpu type %s, please retry with one of the supported gpu types: %v", gpuType.String(), supportedGpuTypes)
}

ctx = namespaces.WithNamespace(ctx, namespaces.Default)
Expand Down Expand Up @@ -132,10 +135,6 @@ func (di *DriverInstaller) InstallGPUDrivers(ctx context.Context) error {
code, _, _ := status.Result()
di.logger.Printf("Gpu driver installation task exited with status: %d\n", code)

err = remountAsExecutable(gpu.InstallationHostDir)
if err != nil {
return fmt.Errorf("failed to remount the installed drivers: %v", err)
}
return nil
}

Expand All @@ -147,16 +146,3 @@ func getInstallerImageReference() (string, error) {
installerImageRef := strings.TrimSpace(string(installerImageRefBytes))
return installerImageRef, nil
}

func remountAsExecutable(dir string) error {
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create dir %q: %v", dir, err)
}
if err := exec.Command("mount", "--bind", dir, dir).Run(); err != nil {
return fmt.Errorf("failed to create bind mount at %q: %v", dir, err)
}
if err := exec.Command("mount", "-o", "remount,exec", dir).Run(); err != nil {
return fmt.Errorf("failed to remount %q: %v", dir, err)
}
return nil
}
9 changes: 2 additions & 7 deletions launcher/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,17 @@ func FetchImpersonatedToken(ctx context.Context, serviceAccount string, audience

func listFilesWithPrefix(targetDir string, prefix string) ([]string, error) {
targetFiles := make([]string, 0)

err := filepath.WalkDir(targetDir, func(path string, _ os.DirEntry, err error) error {
err := filepath.WalkDir(targetDir, func(path string, d os.DirEntry, err error) error {
if err != nil {
return err
}

if strings.HasPrefix(filepath.Base(path), prefix) {
if !d.IsDir() && strings.HasPrefix(filepath.Base(path), prefix) {
targetFiles = append(targetFiles, path)
}

return nil
})

if err != nil {
return nil, fmt.Errorf("error walking directory: %v", err)
}

return targetFiles, nil
}

0 comments on commit d658692

Please sign in to comment.