Skip to content

Commit

Permalink
feat: get GPU devices via FFI
Browse files Browse the repository at this point in the history
A new FFI call (`util::get_gpu_devices()`) is introduced. It returns
the available GPUs as an array of strings.
  • Loading branch information
vmx committed Dec 5, 2019
1 parent 6d9e800 commit 48885f7
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 0 deletions.
5 changes: 5 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,14 @@ fil_logger = "0.1.0"
rand = "0.7"
rayon = "1"
anyhow = "1.0.23"
bellperson = "0.4.4"

[build-dependencies]
cbindgen = "= 0.10.0"

[dev-dependencies]
tempfile = "3.0.8"

[features]
default = ["gpu"]
gpu = ["bellperson/gpu"]
1 change: 1 addition & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ extern crate log;

pub mod bls;
pub mod proofs;
pub mod util;
47 changes: 47 additions & 0 deletions rust/src/util/api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::ffi::CString;

use bellperson::GPU_NVIDIA_DEVICES;
use ffi_toolkit::{catch_panic_response, raw_ptr};

use super::types::GpuDeviceResponse;

/// Returns an array of strings containing the device names that can be used.
#[no_mangle]
pub unsafe extern "C" fn get_gpu_devices() -> *mut GpuDeviceResponse {
catch_panic_response(|| {
let devices: Vec<*const i8> = GPU_NVIDIA_DEVICES
.iter()
.map(|device| {
let name = device.name().unwrap_or("Unknown".to_string());
CString::new(&name[..]).unwrap().as_ptr()
})
.collect();
let mut response = GpuDeviceResponse::default();
response.devices_len = devices.len();
response.devices_ptr = devices.as_ptr();

raw_ptr(response)
})
}

#[cfg(test)]
mod tests {
use std::ffi::CStr;
use std::slice::from_raw_parts;

use crate::util::api::get_gpu_devices;
use crate::util::types::destroy_gpu_device_response;

#[test]
fn test_get_gpu_devices() {
unsafe {
let resp = get_gpu_devices();
let devices: Vec<&str> = from_raw_parts((*resp).devices_ptr, (*resp).devices_len)
.iter()
.map(|name_ptr| CStr::from_ptr(*name_ptr).to_str().unwrap())
.collect();
assert_eq!(devices.len(), (*resp).devices_len);
destroy_gpu_device_response(resp);
}
}
}
2 changes: 2 additions & 0 deletions rust/src/util/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod api;
pub mod types;
32 changes: 32 additions & 0 deletions rust/src/util/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::ptr;

use drop_struct_macro_derive::DropStructMacro;
// `CodeAndMessage` is the trait implemented by `code_and_message_impl
use ffi_toolkit::{code_and_message_impl, free_c_str, CodeAndMessage, FCPResponseStatus};

#[repr(C)]
#[derive(DropStructMacro)]
pub struct GpuDeviceResponse {
pub status_code: FCPResponseStatus,
pub error_msg: *const libc::c_char,
pub devices_len: libc::size_t,
pub devices_ptr: *const *const i8,
}

impl Default for GpuDeviceResponse {
fn default() -> Self {
Self {
error_msg: ptr::null(),
status_code: FCPResponseStatus::FCPNoError,
devices_len: 0,
devices_ptr: ptr::null(),
}
}
}

code_and_message_impl!(GpuDeviceResponse);

#[no_mangle]
pub unsafe extern "C" fn destroy_gpu_device_response(ptr: *mut GpuDeviceResponse) {
let _ = Box::from_raw(ptr);
}

0 comments on commit 48885f7

Please sign in to comment.