Skip to content

Commit

Permalink
Merge pull request #200 from kevinaboos/windows_cuda_support
Browse files Browse the repository at this point in the history
Windows cuda support
  • Loading branch information
jmbejar authored Aug 15, 2024
2 parents 245d104 + 309d2aa commit df682ca
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions moxin-runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,25 @@ const ENV_LD_LIBRARY_PATH: &str = "LD_LIBRARY_PATH";
#[cfg(target_os = "macos")]
const ENV_DYLD_FALLBACK_LIBRARY_PATH: &str = "DYLD_FALLBACK_LIBRARY_PATH";


/// Returns the URL of the WASI-NN plugin that should be downloaded, and its inner directory name.
///
/// Note that this is only used on Windows, because the install_v2.sh script handles it on Linux.
///
/// The plugin selection follows this priority order of hardware features:
/// 1. The CUDA build, if CUDA V12 is installed.
/// 2. The default AVX512 build, if on x86_64 and AVX512F is supported.
/// 3. Otherwise, the noavx build (which itself still requires SSE4.2 or SSE4a).
#[cfg(windows)]
fn wasmedge_wasi_nn_plugin_url() -> (&'static str, &'static str) {
// Currently, WasmEdge's b3499 release only provides a CUDA 12 build for Windows.
if matches!(get_cuda_version(), Some(CudaVersion::V12)) {
return (
"https://github.com/second-state/WASI-NN-GGML-PLUGIN-REGISTRY/releases/download/b3499/WasmEdge-plugin-wasi_nn-ggml-cuda-0.14.0-windows_x86_64.zip",
"WasmEdge-plugin-wasi_nn-ggml-cuda-0.14.0-windows_x86_64",
);
}

#[cfg(target_arch = "x86_64")]
if is_x86_feature_detected!("avx512f") {
return (
Expand Down Expand Up @@ -452,6 +468,44 @@ fn set_env_vars<P: AsRef<Path>>(wasmedge_root_dir_path: &P) {
std::env::set_var(ENV_WASMEDGE_PLUGIN_PATH, wasmedge_root_dir_path.as_ref());
}


/// Versions of CUDA that WasmEdge supports.
enum CudaVersion {
/// CUDA Version 12
V12,
/// CUDA Version 11
V11,
}

/// Attempts to discover what version of CUDA is locally installed, if any.
///
/// This function first runs `nvcc --version` on both Linux and Windows,
/// and if that fails, it will try `/usr/local/cuda/bin/nvcc --version` on Linux only.
fn get_cuda_version() -> Option<CudaVersion> {
let mut output = Command::new("nvcc")
.arg("--version")
.output();

#[cfg(target_os = "linux")] {
output = output.or_else(|_|
Command::new("/usr/local/cuda/bin/nvcc")
.arg("--version")
.output()
);
}

let output = output.ok()?;
let output = String::from_utf8_lossy(&output.stdout);
if output.contains("V12") {
Some(CudaVersion::V12)
} else if output.contains("V11") {
Some(CudaVersion::V11)
} else {
None
}
}


/// Runs the `_moxin_app` binary, which must be located in the same directory as this moxin-runner binary.
///
/// An optional path to the directory containing the main WasmEdge dylib can be provided,
Expand Down

0 comments on commit df682ca

Please sign in to comment.