diff --git a/rust/pact_ffi/src/lib.rs b/rust/pact_ffi/src/lib.rs index 8939c6926..91c6dc692 100644 --- a/rust/pact_ffi/src/lib.rs +++ b/rust/pact_ffi/src/lib.rs @@ -12,7 +12,7 @@ use std::str::FromStr; use lazy_static::lazy_static; use libc::c_char; use tracing::{debug, error, info, trace, warn}; -use tracing_core::Level; +use tracing_core::{Level, LevelFilter}; use tracing_log::AsLog; use tracing_subscriber::FmtSubscriber; @@ -85,12 +85,14 @@ pub unsafe extern fn pactffi_init(log_env_var: *const c_char) { /// Initialises logging, and sets the log level explicitly. This function should only be called /// once, as it tries to install a global tracing subscriber. /// +/// * `log_level` - String (case-insensitive). One of TRACE, DEBUG, INFO, WARN, ERROR, NONE/OFF +/// /// # Safety /// /// Exported functions are inherently unsafe. #[no_mangle] pub unsafe extern "C" fn pactffi_init_with_log_level(level: *const c_char) { - let log_level = log_level_from_c_char(level); + let log_level = log_level_filter_from_c_char(level); let subscriber = FmtSubscriber::builder() .with_max_level(log_level) .with_thread_names(true) @@ -160,6 +162,18 @@ unsafe fn log_level_from_c_char(log_level: *const c_char) -> Level { } } +unsafe fn log_level_filter_from_c_char(log_level: *const c_char) -> LevelFilter { + if !log_level.is_null() { + let level = convert_cstr("log_level", log_level).unwrap_or("INFO"); + match level.to_lowercase().as_str() { + "none" => LevelFilter::OFF, + _ => LevelFilter::from_str(level).unwrap_or(LevelFilter::INFO) + } + } else { + LevelFilter::INFO + } +} + fn convert_cstr(name: &str, value: *const c_char) -> Option<&str> { unsafe { if value.is_null() { @@ -316,3 +330,36 @@ impl MismatchesIterator { idx } } + +#[cfg(test)] +mod tests { + use std::ffi::CString; + + use expectest::prelude::*; + use rstest::rstest; + + use super::log_level_filter_from_c_char; + use tracing_core::LevelFilter; + + #[rstest] + #[case("trace", LevelFilter::TRACE)] + #[case("TRACE", LevelFilter::TRACE)] + #[case("debug", LevelFilter::DEBUG)] + #[case("DEBUG", LevelFilter::DEBUG)] + #[case("info", LevelFilter::INFO)] + #[case("INFO", LevelFilter::INFO)] + #[case("warn", LevelFilter::WARN)] + #[case("WARN", LevelFilter::WARN)] + #[case("error", LevelFilter::ERROR)] + #[case("ERROR", LevelFilter::ERROR)] + #[case("off", LevelFilter::OFF)] + #[case("OFF", LevelFilter::OFF)] + #[case("none", LevelFilter::OFF)] + #[case("NONE", LevelFilter::OFF)] + #[case("invalid", LevelFilter::INFO)] + fn log_level_filter_from_c_char_test(#[case] text: String, #[case] level: LevelFilter) { + let value = CString::new(text).unwrap(); + let result = unsafe { log_level_filter_from_c_char(value.as_ptr()) }; + expect!(result).to(be_equal_to(level)); + } +}