From 2d8937fb46e3c5ca89c6f2801e1548527dea5ecf Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Thu, 24 Aug 2023 12:07:19 +0200 Subject: [PATCH] feat: use nvidia-smi on musl targets (#290) --- crates/rattler_virtual_packages/Cargo.toml | 4 +- crates/rattler_virtual_packages/src/cuda.rs | 71 ++++++++++++++++++++- 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/crates/rattler_virtual_packages/Cargo.toml b/crates/rattler_virtual_packages/Cargo.toml index 28d0e76df..f5969008d 100644 --- a/crates/rattler_virtual_packages/Cargo.toml +++ b/crates/rattler_virtual_packages/Cargo.toml @@ -19,9 +19,7 @@ rattler_conda_types = { version = "0.8.0", path = "../rattler_conda_types" } thiserror = "1.0.43" tracing = "0.1.37" serde = { version = "1.0.171", features = ["derive"] } +regex = "1.9.1" [target.'cfg(target_os="macos")'.dependencies] plist = "1" - -[target.'cfg(unix)'.dependencies] -regex = "1.9.1" diff --git a/crates/rattler_virtual_packages/src/cuda.rs b/crates/rattler_virtual_packages/src/cuda.rs index e6ba92ff5..232248e1c 100644 --- a/crates/rattler_virtual_packages/src/cuda.rs +++ b/crates/rattler_virtual_packages/src/cuda.rs @@ -11,6 +11,7 @@ use libloading::Symbol; use once_cell::sync::OnceCell; use rattler_conda_types::Version; +use std::process::Command; use std::{ mem::MaybeUninit, os::raw::{c_int, c_uint, c_ulong}, @@ -21,10 +22,22 @@ use std::{ pub fn cuda_version() -> Option<Version> { static DETECTED_CUDA_VERSION: OnceCell<Option<Version>> = OnceCell::new(); DETECTED_CUDA_VERSION - .get_or_init(detect_cuda_version_via_nvml) + .get_or_init(detect_cuda_version) .clone() } +/// Attempts to detect the version of CUDA present in the current operating system by employing the +/// best technique available for the current environment. +pub fn detect_cuda_version() -> Option<Version> { + if cfg!(target_env = "musl") { + // Dynamically loading a library is not supported on musl so we have to fall back to using + // the nvidia-smi command.
+ detect_cuda_version_via_nvidia_smi() + } else { + detect_cuda_version_via_nvml() + } +} + /// Attempts to detect the version of CUDA present in the current operating system by loading the /// NVIDIA Management Library and querying the CUDA driver version. The method is preferred over /// [`detect_cuda_version_via_libcuda`] because that method might fail base on environment @@ -192,13 +205,67 @@ fn cuda_library_paths() -> &'static [&'static str] { FILENAMES } +/// Attempts to detect the version of CUDA present in the current operating system by executing the +/// "nvidia-smi" command and extracting the CUDA driver version from it. +/// +/// The behavior of "nvidia-smi" depends on the environment variable `CUDA_VISIBLE_DEVICES`. If +/// users have this variable set in their environment this function will likely not return the +/// correct value. To ensure a consistent response this environment variable is unset when invoking +/// the command. +/// +/// The upside of using this detection function over any of the others is that this method does not +/// dynamically load a library which might not be supported on all systems. The downside is that +/// executing a subprocess is generally slower and more prone to errors. +fn detect_cuda_version_via_nvidia_smi() -> Option<Version> { + // Invoke the "nvidia-smi" command to query the driver version that is usually installed when + // CUDA drivers are installed. + let nvidia_smi_output = Command::new("nvidia-smi") + // Display GPU or unit info + .arg("--query") + // Show unit, rather than GPU, attributes + .arg("-u") + // Produce XML output. + .arg("-x") + // The behavior of "nvidia-smi" depends on the environment variable + // `CUDA_VISIBLE_DEVICES`. If users have this variable set in their environment this + // function will likely not return the correct value. Therefore, we remove this variable + // to ensure a consistent result. + // TODO: Is this really the proper way to do it?
Should we maybe clear the entire + // environment? + .env_remove("CUDA_VISIBLE_DEVICES") + .output() + .ok()?; + + // Convert the output to UTF-8. The conversion is lossy so it might contain some illegal + // characters. If that's the case we simply assume the version in the output also won't make sense + // during parsing. + let output = String::from_utf8_lossy(&nvidia_smi_output.stdout); + static CUDA_VERSION_RE: once_cell::sync::Lazy<regex::Regex> = + once_cell::sync::Lazy::new(|| { + regex::Regex::new("<cuda_version>(.*)<\\/cuda_version>").unwrap() + }); + + // Extract the version from the XML + let version_match = CUDA_VERSION_RE.captures(&output)?; + let version_str = version_match.get(1)?.as_str(); + + // Parse and return + Version::from_str(version_str).ok() +} + #[cfg(test)] mod test { - use super::detect_cuda_version_via_nvml; + use super::*; #[test] pub fn doesnt_crash() { let version = detect_cuda_version_via_nvml(); println!("Cuda {:?}", version); } + + #[test] + pub fn doesnt_crash_nvidia_smi() { + let version = detect_cuda_version_via_nvidia_smi(); + println!("Cuda {:?}", version); + } }