Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use nvidia-smi on musl targets #290

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions crates/rattler_virtual_packages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ rattler_conda_types = { version = "0.8.0", path = "../rattler_conda_types" }
thiserror = "1.0.43"
tracing = "0.1.37"
serde = { version = "1.0.171", features = ["derive"] }
regex = "1.9.1"

[target.'cfg(target_os="macos")'.dependencies]
plist = "1"

[target.'cfg(unix)'.dependencies]
regex = "1.9.1"
71 changes: 69 additions & 2 deletions crates/rattler_virtual_packages/src/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use libloading::Symbol;
use once_cell::sync::OnceCell;
use rattler_conda_types::Version;
use std::process::Command;
use std::{
mem::MaybeUninit,
os::raw::{c_int, c_uint, c_ulong},
Expand All @@ -21,10 +22,22 @@ use std::{
pub fn cuda_version() -> Option<Version> {
    // The detected value is stored in a process-wide cell so that detection runs at most once;
    // every later call returns a clone of the cached result.
    static DETECTED_CUDA_VERSION: OnceCell<Option<Version>> = OnceCell::new();
    DETECTED_CUDA_VERSION
        // `detect_cuda_version` selects the detection technique for the current target, so it
        // is the only initializer needed here.
        .get_or_init(detect_cuda_version)
        .clone()
}

/// Detects the CUDA version available on the current system, choosing the detection technique
/// that fits the target this crate was compiled for.
///
/// On `musl` targets the `nvidia-smi` command is executed; on every other target the NVIDIA
/// Management Library is queried. Returns `None` when the selected technique yields no version.
pub fn detect_cuda_version() -> Option<Version> {
    // Everything except musl can query NVML directly.
    if !cfg!(target_env = "musl") {
        return detect_cuda_version_via_nvml();
    }

    // Dynamically loading a library is not supported on musl, so fall back to running the
    // nvidia-smi command instead.
    detect_cuda_version_via_nvidia_smi()
}

/// Attempts to detect the version of CUDA present in the current operating system by loading the
/// NVIDIA Management Library and querying the CUDA driver version. The method is preferred over
/// [`detect_cuda_version_via_libcuda`] because that method might fail based on environment
Expand Down Expand Up @@ -192,13 +205,67 @@ fn cuda_library_paths() -> &'static [&'static str] {
FILENAMES
}

/// Detects the CUDA version present on the current system by running the "nvidia-smi" command
/// and reading the CUDA driver version from its XML report.
///
/// The behavior of "nvidia-smi" depends on the environment variable `CUDA_VISIBLE_DEVICES`. If
/// users have this variable set in their environment the reported value will likely not be
/// correct, so the variable is removed from the environment of the spawned command.
///
/// Compared to the other detection functions, this one does not dynamically load a library,
/// which might not be supported on all systems. The trade-off is that executing a subprocess is
/// generally slower and more prone to errors.
///
/// Returns `None` if the command cannot be executed, if its output contains no
/// `<cuda_version>` element, or if the text of that element does not parse as a [`Version`].
fn detect_cuda_version_via_nvidia_smi() -> Option<Version> {
    // Captures the text between the `<cuda_version>` tags. The pattern is compiled lazily on
    // first use and reused afterwards.
    static CUDA_VERSION_RE: once_cell::sync::Lazy<regex::Regex> =
        once_cell::sync::Lazy::new(|| {
            regex::Regex::new("<cuda_version>(.*)<\\/cuda_version>").unwrap()
        });

    // "nvidia-smi" is usually installed together with the CUDA drivers. The arguments are:
    //   --query  display GPU or unit info
    //   -u       show unit, rather than GPU, attributes
    //   -x       produce XML output
    let mut command = Command::new("nvidia-smi");
    command
        .args(["--query", "-u", "-x"])
        // Remove `CUDA_VISIBLE_DEVICES` so a user-provided value cannot influence the result.
        // TODO: Is this really the proper way to do it? Should we maybe clear the entire
        // environment.
        .env_remove("CUDA_VISIBLE_DEVICES");
    let process_output = command.output().ok()?;

    // The conversion to UTF-8 is lossy, so the text may contain replacement characters. If that
    // happens we simply assume the version will not parse either.
    let xml = String::from_utf8_lossy(&process_output.stdout);

    // Extract the version text from the XML, then parse it.
    CUDA_VERSION_RE
        .captures(&xml)
        .and_then(|captures| captures.get(1))
        .and_then(|version| Version::from_str(version.as_str()).ok())
}

#[cfg(test)]
mod test {
    // The glob import already brings every detection function into scope.
    use super::*;

    /// Smoke test: querying NVML must not crash. The result is only printed, not asserted.
    #[test]
    fn doesnt_crash() {
        let version = detect_cuda_version_via_nvml();
        println!("Cuda {:?}", version);
    }

    /// Smoke test: invoking "nvidia-smi" must not crash. The result is only printed, not
    /// asserted.
    #[test]
    fn doesnt_crash_nvidia_smi() {
        let version = detect_cuda_version_via_nvidia_smi();
        println!("Cuda {:?}", version);
    }
}