Skip to content

Commit

Permalink
feat: use nvidia-smi on musl targets (#290)
Browse files Browse the repository at this point in the history
  • Loading branch information
baszalmstra authored Aug 24, 2023
1 parent fc88e12 commit 2d8937f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
4 changes: 1 addition & 3 deletions crates/rattler_virtual_packages/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ rattler_conda_types = { version = "0.8.0", path = "../rattler_conda_types" }
thiserror = "1.0.43"
tracing = "0.1.37"
serde = { version = "1.0.171", features = ["derive"] }
regex = "1.9.1"

[target.'cfg(target_os="macos")'.dependencies]
plist = "1"

[target.'cfg(unix)'.dependencies]
regex = "1.9.1"
71 changes: 69 additions & 2 deletions crates/rattler_virtual_packages/src/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use libloading::Symbol;
use once_cell::sync::OnceCell;
use rattler_conda_types::Version;
use std::process::Command;
use std::{
mem::MaybeUninit,
os::raw::{c_int, c_uint, c_ulong},
Expand All @@ -21,10 +22,22 @@ use std::{
pub fn cuda_version() -> Option<Version> {
static DETECTED_CUDA_VERSION: OnceCell<Option<Version>> = OnceCell::new();
DETECTED_CUDA_VERSION
.get_or_init(detect_cuda_version_via_nvml)
.get_or_init(detect_cuda_version)
.clone()
}

/// Attempts to detect the version of CUDA present in the current operating system by employing the
/// best technique available for the current environment.
pub fn detect_cuda_version() -> Option<Version> {
if cfg!(target_env = "musl") {
// Dynamically loading a library is not supported on musl so we have to fall-back to using
// the nvidia-smi command.
detect_cuda_version_via_nvidia_smi()
} else {
detect_cuda_version_via_nvml()
}
}

/// Attempts to detect the version of CUDA present in the current operating system by loading the
/// NVIDIA Management Library and querying the CUDA driver version. The method is preferred over
/// [`detect_cuda_version_via_libcuda`] because that method might fail base on environment
Expand Down Expand Up @@ -192,13 +205,67 @@ fn cuda_library_paths() -> &'static [&'static str] {
FILENAMES
}

/// Attempts to detect the version of CUDA present in the current operating system by executing the
/// "nvidia-smi" command and extracting the CUDA driver version from it.
///
/// The behavior of "nvidia-smi" depends on the environment variable `CUDA_VISIBLE_DEVICES`. If
/// users have this variable set in their environment this function will likely not return the
/// correct value. To ensure a consistent response this environment variable is unset when invoking
/// the command.
///
/// The upside of using this detection function over any of the others is that this method does not
/// dynamically load a library which might not be supported on all systems. The downside is that
/// executing a subprocess is generally slower and more prone to errors.
fn detect_cuda_version_via_nvidia_smi() -> Option<Version> {
// Invoke the "nvidia-smi" command to query the driver version that is usually installed when
// Cuda drivers are installed.
let nvidia_smi_output = Command::new("nvidia-smi")
// Display GPU or unit info
.arg("--query")
// Show unit, rather than GPU, attributes
.arg("-u")
// Produce XML output.
.arg("-x")
// The behavior of functions from `libcuda` depend on the environment variable
// `CUDA_VISIBLE_DEVICES`. If users have this variable set in their environment this
// function will likely not return the correct value. Therefor, we remove this variable
// to ensure a consistent result.
// TODO: Is this really the proper way to do it? Should we maybe clear the entire
// environment.
.env_remove("CUDA_VISIBLE_DEVICES")
.output()
.ok()?;

// Convert the output to Utf8. The conversion is lossy so it might contain some illegal
// characters. If thats the case we simply assume the version in the file also wont make sense
// during parsing.
let output = String::from_utf8_lossy(&nvidia_smi_output.stdout);
static CUDA_VERSION_RE: once_cell::sync::Lazy<regex::Regex> =
once_cell::sync::Lazy::new(|| {
regex::Regex::new("<cuda_version>(.*)<\\/cuda_version>").unwrap()
});

// Extract the version from the XML
let version_match = CUDA_VERSION_RE.captures(&output)?;
let version_str = version_match.get(1)?.as_str();

// Parse and return
Version::from_str(version_str).ok()
}

#[cfg(test)]
mod test {
use super::detect_cuda_version_via_nvml;
use super::*;

#[test]
pub fn doesnt_crash() {
let version = detect_cuda_version_via_nvml();
println!("Cuda {:?}", version);
}

#[test]
pub fn doesnt_crash_nvidia_smi() {
let version = detect_cuda_version_via_nvidia_smi();
println!("Cuda {:?}", version);
}
}

0 comments on commit 2d8937f

Please sign in to comment.