Skip to content

Commit

Permalink
feat: expose ortsys macro; make api return a reference
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Nov 4, 2024
1 parent 516db5f commit 552727e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 27 deletions.
40 changes: 19 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![doc(html_logo_url = "https://raw.githubusercontent.com/pykeio/ort/v2/docs/icon.png")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![allow(clippy::tabs_in_doc_comments, clippy::arc_with_non_send_sync)]
#![allow(clippy::macro_metavars_in_unsafe)]
#![warn(clippy::unwrap_used)]

//! <div align=center>
Expand Down Expand Up @@ -165,23 +166,23 @@ pub fn info() -> &'static str {
/// ```
/// # use std::ffi::CStr;
/// # fn main() -> ort::Result<()> {
/// let api = ort::api().as_ptr();
/// let build_info = unsafe { CStr::from_ptr((*api).GetBuildInfoString.unwrap()()) };
/// let api = ort::api();
/// let build_info = unsafe { CStr::from_ptr(api.GetBuildInfoString.unwrap()()) };
/// println!("{}", build_info.to_string_lossy());
/// // ORT Build Info: git-branch=HEAD, git-commit-id=4573740, build type=Release, cmake cxx flags: /DWIN32 /D_WINDOWS /EHsc /EHsc /wd26812 -DEIGEN_HAS_C99_MATH -DCPUINFO_SUPPORTED
/// # Ok(())
/// # }
/// ```
///
/// For the full list of ONNX Runtime APIs, consult the [`ort_sys::OrtApi`] struct and the [ONNX Runtime C API](https://onnxruntime.ai/docs/api/c/struct_ort_api.html).
pub fn api() -> NonNull<ort_sys::OrtApi> {
pub fn api() -> &'static ort_sys::OrtApi {
struct ApiPointer(NonNull<ort_sys::OrtApi>);
unsafe impl Send for ApiPointer {}
unsafe impl Sync for ApiPointer {}

static G_ORT_API: OnceLock<ApiPointer> = OnceLock::new();

G_ORT_API
let ptr = G_ORT_API
.get_or_init(|| {
#[cfg(feature = "load-dynamic")]
unsafe {
Expand Down Expand Up @@ -227,55 +228,52 @@ pub fn api() -> NonNull<ort_sys::OrtApi> {
ApiPointer(NonNull::new(api.cast_mut()).expect("Failed to initialize ORT API"))
}
})
.0
.0;
unsafe { ptr.as_ref() }
}

#[macro_export]
macro_rules! ortsys {
($method:ident) => {
$crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))
};
(unsafe $method:ident) => {
unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null"))) }
$crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))
};
($method:ident($($n:expr),+ $(,)?)) => {
$crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)
$crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)
};
(unsafe $method:ident($($n:expr),+ $(,)?)) => {
unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }
unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }
};
($method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e)
$crate::error::status_to_result($crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+)).expect($e)
};
(unsafe $method:ident($($n:expr),+ $(,)?).expect($e:expr)) => {
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e)
$crate::error::status_to_result(unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) }).expect($e)
};
($method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+);
$crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+);
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:ident($($n:expr),+ $(,)?); nonNull($($check:expr),+ $(,)?)$(;)?) => {{
let _x = unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) };
let _x = unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) };
$($crate::error::assert_non_null_pointer($check, stringify!($method)).unwrap();)+
_x
}};
($method:ident($($n:expr),+ $(,)?)?) => {
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
$crate::error::status_to_result($crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
};
(unsafe $method:ident($($n:expr),+ $(,)?)?) => {
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
$crate::error::status_to_result(unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
};
($method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {
$crate::error::status_to_result($crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
$crate::error::status_to_result($crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+))?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
};
(unsafe $method:ident($($n:expr),+ $(,)?)?; nonNull($($check:expr),+ $(,)?)$(;)?) => {{
$crate::error::status_to_result(unsafe { $crate::api().as_ref().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
$crate::error::status_to_result(unsafe { $crate::api().$method.unwrap_or_else(|| unreachable!(concat!("Method `", stringify!($method), "` is null")))($($n),+) })?;
$($crate::error::assert_non_null_pointer($check, stringify!($method))?;)+
}};
}

pub(crate) use ortsys;

pub(crate) fn char_p_to_string(raw: *const c_char) -> Result<String> {
if raw.is_null() {
return Ok(String::new());
Expand Down
12 changes: 6 additions & 6 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -505,12 +505,12 @@ mod dangerous {
use super::*;

pub(super) fn extract_inputs_count(session_ptr: NonNull<ort_sys::OrtSession>) -> Result<usize> {
let f = ortsys![unsafe SessionGetInputCount];
let f = ortsys![SessionGetInputCount];
extract_io_count(f, session_ptr)
}

pub(super) fn extract_outputs_count(session_ptr: NonNull<ort_sys::OrtSession>) -> Result<usize> {
let f = ortsys![unsafe SessionGetOutputCount];
let f = ortsys![SessionGetOutputCount];
extract_io_count(f, session_ptr)
}

Expand All @@ -525,12 +525,12 @@ mod dangerous {
}

fn extract_input_name(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<String> {
let f = ortsys![unsafe SessionGetInputName];
let f = ortsys![SessionGetInputName];
extract_io_name(f, session_ptr, allocator, i)
}

fn extract_output_name(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<String> {
let f = ortsys![unsafe SessionGetOutputName];
let f = ortsys![SessionGetOutputName];
extract_io_name(f, session_ptr, allocator, i)
}

Expand Down Expand Up @@ -568,14 +568,14 @@ mod dangerous {

pub(super) fn extract_input(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Input> {
let input_name = extract_input_name(session_ptr, allocator, i)?;
let f = ortsys![unsafe SessionGetInputTypeInfo];
let f = ortsys![SessionGetInputTypeInfo];
let input_type = extract_io(f, session_ptr, i)?;
Ok(Input { name: input_name, input_type })
}

pub(super) fn extract_output(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Output> {
let output_name = extract_output_name(session_ptr, allocator, i)?;
let f = ortsys![unsafe SessionGetOutputTypeInfo];
let f = ortsys![SessionGetOutputTypeInfo];
let output_type = extract_io(f, session_ptr, i)?;
Ok(Output { name: output_name, output_type })
}
Expand Down

0 comments on commit 552727e

Please sign in to comment.