Skip to content

Commit

Permalink
patchpkg: patch python to use devbox CUDA libs (#2296)
Browse files Browse the repository at this point in the history
Automatically patch python to use any `cudaPackages.*` packages that are
in devbox.json. This will only work if the CUDA drivers are already
installed on the host system.

The patching process is:

1. When generating the patch flake, look for the system’s `libcuda.so`
   library (installed by the driver) and copy it into the flake’s
   directory.
2. Nix copies the flake’s source directory (and therefore libcuda.so)
   into the nix store when building it.
3. The flake calls `devbox patch` which adds a `DT_NEEDED` entry to the
   python binary for `libcuda.so`. It also adds the lib directories of
   any other `cudaPackages.*` packages that it finds in the
   environment.
  • Loading branch information
gcurtis authored Sep 27, 2024
1 parent 58ed80e commit 57312c0
Show file tree
Hide file tree
Showing 6 changed files with 692 additions and 30 deletions.
28 changes: 15 additions & 13 deletions internal/devpkg/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@ func resolve(pkg *Package) error {
if err != nil {
return err
}

// TODO savil. Check with Greg about setting the user-specified outputs
// somehow here.
parsed.Outputs = strings.Join(pkg.outputs.selectedNames, ",")

pkg.setInstallable(parsed, pkg.lockfile.ProjectDir())
return nil
Expand Down Expand Up @@ -308,7 +306,10 @@ func (p *Package) InstallableForOutput(output string) (string, error) {
// a valid flake reference parsable by ParseFlakeRef, optionally followed by an
// #attrpath and/or an ^output.
func (p *Package) FlakeInstallable() (flake.Installable, error) {
return flake.ParseInstallable(p.Raw)
if err := p.resolve(); err != nil {
return flake.Installable{}, err
}
return p.installable, nil
}

// urlForInstall is used during `nix profile install`.
Expand All @@ -322,15 +323,16 @@ func (p *Package) urlForInstall() (string, error) {
}

func (p *Package) NormalizedDevboxPackageReference() (string, error) {
if err := p.resolve(); err != nil {
installable, err := p.FlakeInstallable()
if err != nil {
return "", err
}
if p.installable.AttrPath == "" {
if installable.AttrPath == "" {
return "", nil
}
clone := p.installable
clone.AttrPath = fmt.Sprintf("legacyPackages.%s.%s", nix.System(), clone.AttrPath)
return clone.String(), nil
installable.AttrPath = fmt.Sprintf("legacyPackages.%s.%s", nix.System(), installable.AttrPath)
installable.Outputs = ""
return installable.String(), nil
}

// PackageAttributePath returns the short attribute path for a package which
Expand Down Expand Up @@ -376,19 +378,19 @@ func (p *Package) NormalizedPackageAttributePath() (string, error) {
// normalizePackageAttributePath calls nix search to find the normalized attribute
// path. It may be an expensive call (~100ms).
func (p *Package) normalizePackageAttributePath() (string, error) {
if err := p.resolve(); err != nil {
installable, err := p.FlakeInstallable()
if err != nil {
return "", err
}

query := p.installable.String()
installable.Outputs = ""
query := installable.String()
if query == "" {
query = p.Raw
}

// We prefer nix.Search over just trying to parse the package's "URL" because
// nix.Search will guarantee that the package exists for the current system.
var infos map[string]*nix.Info
var err error
if p.IsDevboxPackage && !p.IsRunX() {
// Perf optimization: For queries of the form nixpkgs/<commit>#foo, we can
// use a nix.Search cache.
Expand Down
85 changes: 84 additions & 1 deletion internal/patchpkg/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"path"
"path/filepath"
"regexp"
"strings"
)

//go:embed glibc-patch.bash
Expand All @@ -42,6 +43,10 @@ type DerivationBuilder struct {

RestoreRefs bool
bytePatches map[string][]fileSlice

// src contains the source files of the derivation. For flakes, this is
// anything in the flake.nix directory.
src *packageFS
}

// NewDerivationBuilder initializes a new DerivationBuilder from the current
Expand Down Expand Up @@ -79,6 +84,9 @@ func (d *DerivationBuilder) init() error {
return fmt.Errorf("patchpkg: can't patch gcc using %s: %v", d.Gcc, err)
}
}
if src := os.Getenv("src"); src != "" {
d.src = newPackageFS(src)
}
return nil
}

Expand All @@ -95,6 +103,11 @@ func (d *DerivationBuilder) Build(ctx context.Context, pkgStorePath string) erro
}

func (d *DerivationBuilder) build(ctx context.Context, pkg, out *packageFS) error {
// Create the derivation's $out directory.
if err := d.copyDir(out, "."); err != nil {
return err
}

if d.RestoreRefs {
if err := d.restoreMissingRefs(ctx, pkg); err != nil {
// Don't break the flake build if we're unable to
Expand All @@ -103,12 +116,19 @@ func (d *DerivationBuilder) build(ctx context.Context, pkg, out *packageFS) erro
slog.ErrorContext(ctx, "unable to restore all removed refs", "err", err)
}
}
if err := d.findCUDA(ctx, out); err != nil {
slog.ErrorContext(ctx, "unable to patch CUDA libraries", "err", err)
}

var err error
for path, entry := range allFiles(pkg, ".") {
if ctx.Err() != nil {
return ctx.Err()
}
if path == "." {
// Skip the $out directory - we already created it.
continue
}

switch {
case entry.IsDir():
Expand Down Expand Up @@ -167,7 +187,7 @@ func (d *DerivationBuilder) copyDir(out *packageFS, path string) error {
if err != nil {
return err
}
return os.Mkdir(path, 0o777)
return os.MkdirAll(path, 0o777)
}

func (d *DerivationBuilder) copyFile(ctx context.Context, pkg, out *packageFS, path string) error {
Expand Down Expand Up @@ -302,6 +322,69 @@ func (d *DerivationBuilder) findRemovedRefs(ctx context.Context, pkg *packageFS)
return refs, nil
}

func (d *DerivationBuilder) findCUDA(ctx context.Context, out *packageFS) error {
if d.src == nil {
return fmt.Errorf("patch flake didn't set $src to the path to its source tree")
}

glob, err := fs.Glob(d.src, "lib/libcuda.so*")
if err != nil {
return fmt.Errorf("glob system libraries: %v", err)
}
if len(glob) != 0 {
err := d.copyDir(out, "lib")
if err != nil {
return fmt.Errorf("copy system library: %v", err)
}
}
for _, lib := range glob {
slog.DebugContext(ctx, "found system CUDA library in flake", "path", lib)

err := d.copyFile(ctx, d.src, out, lib)
if err != nil {
return fmt.Errorf("copy system library: %v", err)
}
need, err := out.OSPath(lib)
if err != nil {
return fmt.Errorf("get absolute path to library: %v", err)
}
d.glibcPatcher.needed = append(d.glibcPatcher.needed, need)

slog.DebugContext(ctx, "added DT_NEEDED entry for system CUDA library", "path", need)
}

slog.DebugContext(ctx, "looking for nix libraries in $patchDependencies")
deps := os.Getenv("patchDependencies")
if strings.TrimSpace(deps) == "" {
slog.DebugContext(ctx, "$patchDependencies is empty")
return nil
}
for _, pkg := range strings.Split(deps, " ") {
slog.DebugContext(ctx, "checking for nix libraries in package", "pkg", pkg)

pkgFS := newPackageFS(pkg)
libs, err := fs.Glob(pkgFS, "lib*/*.so*")
if err != nil {
return fmt.Errorf("glob nix package libraries: %v", err)
}

sonameRegexp := regexp.MustCompile(`(^|/).+\.so\.\d+`)
for _, lib := range libs {
if !sonameRegexp.MatchString(lib) {
continue
}
need, err := pkgFS.OSPath(lib)
if err != nil {
return fmt.Errorf("get absolute path to nix package library: %v", err)
}
d.glibcPatcher.needed = append(d.glibcPatcher.needed, need)

slog.DebugContext(ctx, "added DT_NEEDED entry for nix library", "path", need)
}
}
return nil
}

// packageFS is the tree of files for a package in the Nix store.
type packageFS struct {
fs.FS
Expand Down
96 changes: 96 additions & 0 deletions internal/patchpkg/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"fmt"
"io"
"io/fs"
"iter"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -80,3 +82,97 @@ func searchEnv(re *regexp.Regexp) string {
}
return ""
}

// SystemCUDALibraries returns an iterator over the system CUDA library paths.
// It yields them in priority order, where the first path is the most likely to
// be the correct version.
var SystemCUDALibraries iter.Seq[string] = func(yield func(string) bool) {
// Quick overview of Unix-like shared object versioning.
//
// Libraries have 3 different names (using libcuda as an example):
//
// 1. libcuda.so - the "linker name". Typically a symlink pointing to
// the soname. The compiler looks for this name.
// 2. libcuda.so.1 - the "soname". Typically a symlink pointing to the
// real name. The dynamic linker looks for this name.
// 3. libcuda.so.550.107.02 - the "real name". The actual ELF shared
// library. Usually never referred to directly because that would
// make versioning hard.
//
// Because we don't know what version of CUDA the user's program
// actually needs, we're going to try to find linker names (libcuda.so)
// and trust that the system is pointing it to the correct version.
// We'll fall back to sonames (libcuda.so.1) that we find if none of the
// linker names work.

// Common direct paths to try first.
linkerNames := []string{
"/usr/lib/x86_64-linux-gnu/libcuda.so", // Debian
"/usr/lib64/libcuda.so", // Red Hat
"/usr/lib/libcuda.so",
}
for _, path := range linkerNames {
// Return what the link is pointing to because the dynamic
// linker will want libcuda.so.1, not libcuda.so.
soname, err := os.Readlink(path)
if err != nil {
continue
}
if filepath.IsLocal(soname) {
soname = filepath.Join(filepath.Dir(path), soname)
}
if !yield(soname) {
return
}
}

// Directories to recursively search.
prefixes := []string{
"/usr/lib",
"/usr/lib64",
"/lib",
"/lib64",
"/usr/local/lib",
"/usr/local/lib64",
"/opt/cuda",
"/opt/nvidia",
"/usr/local/cuda",
"/usr/local/nvidia",
}
sonameRegex := regexp.MustCompile(`^libcuda\.so\.\d+$`)
var sonames []string
for _, path := range prefixes {
_ = filepath.WalkDir(path, func(path string, entry fs.DirEntry, err error) error {
if err != nil {
return nil
}
if entry.Name() == "libcuda.so" && isSymlink(entry.Type()) {
soname, err := os.Readlink(path)
if err != nil {
return nil
}
if filepath.IsLocal(soname) {
soname = filepath.Join(filepath.Dir(path), soname)
}
if !yield(soname) {
return filepath.SkipAll
}
}

// Save potential soname matches for later after we've
// exhausted all the potential linker names.
if sonameRegex.MatchString(entry.Name()) {
sonames = append(sonames, entry.Name())
}
return nil
})
}

// We didn't find any symlinks named libcuda.so. Fall back to trying any
// sonames (e.g., libcuda.so.1) that we found.
for _, path := range sonames {
if !yield(path) {
return
}
}
}
Loading

0 comments on commit 57312c0

Please sign in to comment.