Skip to content

Commit

Permalink
refactor: make tracing optional
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Dec 28, 2024
1 parent 68b5e24 commit 988e92d
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 61 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ libloading = { version = "0.8", optional = true }

ureq = { version = "2.1", optional = true, default-features = false, features = [ "tls" ] }
sha2 = { version = "0.10", optional = true }
tracing = { version = "0.1", default-features = false, features = [ "std" ] }
tracing = { version = "0.1", optional = true, default-features = false, features = [ "std" ] }
half = { version = "2.1", optional = true }

[dev-dependencies]
Expand Down
83 changes: 37 additions & 46 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@
use std::{
any::Any,
ffi::{self, CStr, CString},
ffi::CString,
os::raw::c_void,
ptr::{self, NonNull},
sync::{Arc, RwLock}
};

use ort_sys::c_char;
use tracing::{Level, debug};

#[cfg(feature = "load-dynamic")]
use crate::G_ORT_DYLIB_PATH;
use crate::{AsPointer, error::Result, execution_providers::ExecutionProviderDispatch, ortsys};
Expand Down Expand Up @@ -66,7 +63,7 @@ impl AsPointer for Environment {

impl Drop for Environment {
fn drop(&mut self) {
debug!(ptr = ?self.ptr(), "Releasing environment");
crate::debug!(ptr = ?self.ptr(), "Releasing environment");
ortsys![unsafe ReleaseEnv(self.ptr_mut())];
}
}
Expand All @@ -81,7 +78,7 @@ pub fn get_environment() -> Result<Arc<Environment>> {
// drop our read lock so we dont deadlock when `commit` takes a write lock
drop(env);

debug!("Environment not yet initialized, creating a new one");
crate::debug!("Environment not yet initialized, creating a new one");
Ok(EnvironmentBuilder::new().commit()?)
}
}
Expand Down Expand Up @@ -191,11 +188,13 @@ pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
.cast_const()
.cast::<ort_sys::OrtCustomHandleType>(),
Ok(Err(e)) => {
tracing::error!("Failed to create thread using manager: {e}");
crate::error!("Failed to create thread using manager: {e}");
let _ = e;
ptr::null()
}
Err(e) => {
tracing::error!("Thread manager panicked: {e:?}");
crate::error!("Thread manager panicked: {e:?}");
let _ = e;
ptr::null()
}
}
Expand All @@ -204,7 +203,8 @@ pub(crate) unsafe extern "system" fn thread_create<T: ThreadManager + Any>(
pub(crate) unsafe extern "system" fn thread_join<T: ThreadManager + Any>(ort_custom_thread_handle: ort_sys::OrtCustomThreadHandle) {
let handle = Box::from_raw(ort_custom_thread_handle.cast_mut().cast::<<T as ThreadManager>::Thread>());
if let Err(e) = <T as ThreadManager>::join(*handle) {
tracing::error!("Failed to join thread using manager: {e}");
crate::error!("Failed to join thread using manager: {e}");
let _ = e;
}
}

Expand Down Expand Up @@ -279,43 +279,61 @@ impl EnvironmentBuilder {
pub fn commit(self) -> Result<Arc<Environment>> {
let (env_ptr, thread_manager, has_global_threadpool) = if let Some(mut thread_pool_options) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());

#[cfg(feature = "tracing")]
ortsys![
unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
logging_function,
logger_param,
Some(crate::logging::custom_logger),
ptr::null_mut(),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
thread_pool_options.ptr(),
&mut env_ptr
)?;
nonNull(env_ptr)
];
#[cfg(not(feature = "tracing"))]
ortsys![
unsafe CreateEnvWithGlobalThreadPools(
crate::logging::default_log_level(),
cname.as_ptr(),
thread_pool_options.ptr(),
&mut env_ptr
)?;
nonNull(env_ptr)
];

let thread_manager = thread_pool_options.thread_manager.take();
(env_ptr, thread_manager, true)
} else {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap_or_else(|_| unreachable!());

#[cfg(feature = "tracing")]
ortsys![
unsafe CreateEnvWithCustomLogger(
logging_function,
logger_param,
Some(crate::logging::custom_logger),
ptr::null_mut(),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
&mut env_ptr
)?;
nonNull(env_ptr)
];
#[cfg(not(feature = "tracing"))]
ortsys![
unsafe CreateEnv(
crate::logging::default_log_level(),
cname.as_ptr(),
&mut env_ptr
)?;
nonNull(env_ptr)
];

(env_ptr, None, false)
};
debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");
crate::debug!(env_ptr = format!("{env_ptr:?}").as_str(), "Environment created");

if self.telemetry {
ortsys![unsafe EnableTelemetryEvents(env_ptr)?];
Expand Down Expand Up @@ -394,30 +412,3 @@ pub fn init_from(path: impl ToString) -> EnvironmentBuilder {
let _ = G_ORT_DYLIB_PATH.set(Arc::new(path.to_string()));
EnvironmentBuilder::new()
}

/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
pub(crate) extern "system" fn custom_logger(
_params: *mut ffi::c_void,
severity: ort_sys::OrtLoggingLevel,
_: *const c_char,
id: *const c_char,
code_location: *const c_char,
message: *const c_char
) {
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
assert_ne!(id, ptr::null());
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");

let span = tracing::span!(Level::TRACE, "ort", id = id, location = code_location);

match severity {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, Level::TRACE, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, Level::INFO, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, Level::WARN, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, Level::ERROR, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, Level::ERROR, "(FATAL): {message}")
}
}
10 changes: 5 additions & 5 deletions src/execution_providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,20 +263,20 @@ pub(crate) fn apply_execution_providers(
.ends_with("was not registered because its corresponding Cargo feature is not enabled.")
{
if ex.inner.supported_by_platform() {
tracing::warn!("{e}");
crate::warn!("{e}");
} else {
tracing::debug!("{e} (note: additionally, `{}` is not supported on this platform)", ex.inner.as_str());
crate::debug!("{e} (note: additionally, `{}` is not supported on this platform)", ex.inner.as_str());
}
} else {
tracing::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str());
crate::error!("An error occurred when attempting to register `{}`: {e}", ex.inner.as_str());
}
} else {
tracing::info!("Successfully registered `{}`", ex.inner.as_str());
crate::info!("Successfully registered `{}`", ex.inner.as_str());
fallback_to_cpu = false;
}
}
if fallback_to_cpu {
tracing::warn!("No execution providers registered successfully. Falling back to CPU.");
crate::warn!("No execution providers registered successfully. Falling back to CPU.");
}
Ok(())
}
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod environment;
pub mod error;
pub mod execution_providers;
pub mod io_binding;
pub(crate) mod logging;
pub mod memory;
pub mod metadata;
pub mod operator;
Expand All @@ -39,6 +40,7 @@ pub use ort_sys as sys;

#[cfg(feature = "load-dynamic")]
pub use self::environment::init_from;
pub(crate) use self::logging::{debug, error, info, trace, warning as warn};
pub use self::{
environment::init,
error::{Error, ErrorCode, Result}
Expand Down Expand Up @@ -138,7 +140,7 @@ pub fn api() -> &'static ort_sys::OrtApi {

let version_string = ((*base).GetVersionString)();
let version_string = CStr::from_ptr(version_string).to_string_lossy();
tracing::info!("Loaded ONNX Runtime dylib with version '{version_string}'");
crate::info!("Loaded ONNX Runtime dylib with version '{version_string}'");

let lib_minor_version = version_string.split('.').nth(1).map_or(0, |x| x.parse::<u32>().unwrap_or(0));
match lib_minor_version.cmp(&MINOR_VERSION) {
Expand All @@ -147,7 +149,7 @@ pub fn api() -> &'static ort_sys::OrtApi {
env!("CARGO_PKG_VERSION"),
dylib_path()
),
std::cmp::Ordering::Greater => tracing::warn!(
std::cmp::Ordering::Greater => crate::warn!(
"ort {} may have compatibility issues with the ONNX Runtime binary found at `{}`; expected GetVersionString to return '1.{MINOR_VERSION}.x', but got '{version_string}'",
env!("CARGO_PKG_VERSION"),
dylib_path()
Expand Down
81 changes: 81 additions & 0 deletions src/logging.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#[cfg(feature = "tracing")]
use std::{
ffi::{self, CStr},
ptr
};

macro_rules! trace {
($($arg:tt)+) => {
#[cfg(feature = "tracing")]
tracing::trace!($($arg)+);
}
}
macro_rules! debug {
($($arg:tt)+) => {
#[cfg(feature = "tracing")]
tracing::debug!($($arg)+);
}
}
macro_rules! info {
($($arg:tt)+) => {
#[cfg(feature = "tracing")]
tracing::info!($($arg)+);
}
}
macro_rules! warning {
($($arg:tt)+) => {
#[cfg(feature = "tracing")]
tracing::warn!($($arg)+);
}
}
macro_rules! error {
($($arg:tt)+) => {
#[cfg(feature = "tracing")]
tracing::error!($($arg)+);
}
}
pub(crate) use debug;
pub(crate) use error;
pub(crate) use info;
pub(crate) use trace;
pub(crate) use warning;

#[cfg(not(feature = "tracing"))]
pub fn default_log_level() -> ort_sys::OrtLoggingLevel {
match std::env::var("ORT_LOG").as_deref() {
Ok("fatal") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL,
Ok("error") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR,
Ok("warning") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING,
Ok("info") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
Ok("verbose") => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
_ => ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR
}
}

/// Callback from C that will handle ONNX logging, forwarding ONNX's logs to the `tracing` crate.
#[cfg(feature = "tracing")]
pub(crate) extern "system" fn custom_logger(
_params: *mut ffi::c_void,
severity: ort_sys::OrtLoggingLevel,
_: *const ffi::c_char,
id: *const ffi::c_char,
code_location: *const ffi::c_char,
message: *const ffi::c_char
) {
assert_ne!(code_location, ptr::null());
let code_location = unsafe { CStr::from_ptr(code_location) }.to_str().unwrap_or("<decode error>");
assert_ne!(message, ptr::null());
let message = unsafe { CStr::from_ptr(message) }.to_str().unwrap_or("<decode error>");
assert_ne!(id, ptr::null());
let id = unsafe { CStr::from_ptr(id) }.to_str().unwrap_or("<decode error>");

let span = tracing::span!(tracing::Level::TRACE, "ort", id = id, location = code_location);

match severity {
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE => tracing::event!(parent: &span, tracing::Level::TRACE, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO => tracing::event!(parent: &span, tracing::Level::INFO, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING => tracing::event!(parent: &span, tracing::Level::WARN, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR => tracing::event!(parent: &span, tracing::Level::ERROR, "{message}"),
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_FATAL => tracing::event!(parent: &span, tracing::Level::ERROR, "(FATAL): {message}")
}
}
6 changes: 3 additions & 3 deletions src/session/builder/impl_commit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@ impl SessionBuilder {
});
let model_filepath = download_dir.join(&model_filename);
let downloaded_path = if model_filepath.exists() {
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
model_filepath
} else {
tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");
crate::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{url:?}").as_str(), "Downloading model");

let resp = ureq::get(url).call().map_err(|e| Error::new(format!("Error downloading to file: {e}")))?;

let len = resp
.header("Content-Length")
.and_then(|s| s.parse::<usize>().ok())
.expect("Missing Content-Length header");
tracing::info!(len, "Downloading {} bytes", len);
crate::info!(len, "Downloading {} bytes", len);

let mut reader = resp.into_reader();
let temp_filepath = download_dir.join(format!("tmp_{}.{model_filename}", ort_sys::internal::random_identifier()));
Expand Down
2 changes: 1 addition & 1 deletion src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl AsPointer for SharedSessionInner {

impl Drop for SharedSessionInner {
fn drop(&mut self) {
tracing::debug!(ptr = ?self.session_ptr.as_ptr(), "dropping SharedSessionInner");
crate::debug!(ptr = ?self.session_ptr.as_ptr(), "dropping SharedSessionInner");
ortsys![unsafe ReleaseSession(self.session_ptr.as_ptr())];
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/training/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl AsPointer for Checkpoint {

impl Drop for Checkpoint {
fn drop(&mut self) {
tracing::trace!("dropping checkpoint");
crate::trace!("dropping checkpoint");
trainsys![unsafe ReleaseCheckpointState(self.ptr.as_ptr())];
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/training/trainer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ impl AsPointer for Trainer {

impl Drop for Trainer {
fn drop(&mut self) {
tracing::trace!("dropping trainer");
crate::trace!("dropping trainer");
trainsys![unsafe ReleaseTrainingSession(self.ptr.as_ptr())];
}
}
2 changes: 1 addition & 1 deletion src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl AsPointer for ValueInner {
impl Drop for ValueInner {
fn drop(&mut self) {
let ptr = self.ptr_mut();
tracing::trace!("dropping value at {ptr:p}");
crate::trace!("dropping value at {ptr:p}");
if self.drop {
ortsys![unsafe ReleaseValue(ptr)];
}
Expand Down

0 comments on commit 988e92d

Please sign in to comment.