diff --git a/src/call.rs b/src/call.rs index a18f05da91..9aa3ae667f 100644 --- a/src/call.rs +++ b/src/call.rs @@ -44,9 +44,7 @@ pub fn c_try(ret: libc::c_int) -> Result { } pub fn last_error(code: libc::c_int) -> Error { - // nowadays this unwrap is safe as `Error::last_error` always returns - // `Some`. - Error::last_error(code).unwrap() + Error::last_error(code) } mod impls { diff --git a/src/error.rs b/src/error.rs index e57bae27dd..076667af98 100644 --- a/src/error.rs +++ b/src/error.rs @@ -32,12 +32,7 @@ impl Error { /// /// The `code` argument typically comes from the return value of a function /// call. This code will later be returned from the `code` function. - /// - /// Historically this function returned `Some` or `None` based on the return - /// value of `git_error_last` but nowadays it always returns `Some` so it's - /// safe to unwrap the return value. This API will change in the next major - /// version. - pub fn last_error(code: c_int) -> Option { + pub fn last_error(code: c_int) -> Error { crate::init(); unsafe { // Note that whenever libgit2 returns an error any negative value @@ -64,7 +59,7 @@ impl Error { Error::from_raw(code, ptr) }; raw::git_error_clear(); - Some(err) + err } } diff --git a/src/indexer.rs b/src/indexer.rs index 0aaf353d53..ddca5fa2d5 100644 --- a/src/indexer.rs +++ b/src/indexer.rs @@ -188,10 +188,7 @@ impl io::Write for Indexer<'_> { let res = raw::git_indexer_append(self.raw, ptr, len, &mut self.progress); if res < 0 { - Err(io::Error::new( - io::ErrorKind::Other, - Error::last_error(res).unwrap(), - )) + Err(io::Error::new(io::ErrorKind::Other, Error::last_error(res))) } else { Ok(buf.len()) } diff --git a/src/odb.rs b/src/odb.rs index d01c70ae67..2019908c48 100644 --- a/src/odb.rs +++ b/src/odb.rs @@ -458,7 +458,7 @@ impl<'repo> OdbPackwriter<'repo> { }; if res < 0 { - Err(Error::last_error(res).unwrap()) + Err(Error::last_error(res)) } else { Ok(res) } diff --git a/src/repo.rs b/src/repo.rs index b94b4007db..2b3e60b2af 100644 --- a/src/repo.rs +++ b/src/repo.rs @@ -848,7 +848,7 @@ impl Repository { match value { 0 => Ok(false), 1 => Ok(true), - _ => Err(Error::last_error(value).unwrap()), + _ => Err(Error::last_error(value)), } } } diff --git a/src/tracing.rs b/src/tracing.rs index 5acae8a850..9872571dd3 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -1,8 +1,11 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + ffi::CStr, + sync::atomic::{AtomicPtr, Ordering}, +}; -use libc::c_char; +use libc::{c_char, c_int}; -use crate::{panic, raw, util::Binding}; +use crate::{panic, raw, util::Binding, Error}; /// Available tracing levels. When tracing is set to a particular level, /// callers will be provided tracing at the given level and all lower levels. @@ -57,29 +60,81 @@ impl Binding for TraceLevel { } } -//TODO: pass raw &[u8] and leave conversion to consumer (breaking API) /// Callback type used to pass tracing events to the subscriber. /// see `trace_set` to register a subscriber. -pub type TracingCb = fn(TraceLevel, &str); +pub type TracingCb = fn(TraceLevel, &[u8]); -static CALLBACK: AtomicUsize = AtomicUsize::new(0); +/// Use an atomic pointer to store the global tracing subscriber function. +static CALLBACK: AtomicPtr<()> = AtomicPtr::new(std::ptr::null_mut()); -/// -pub fn trace_set(level: TraceLevel, cb: TracingCb) -> bool { - CALLBACK.store(cb as usize, Ordering::SeqCst); +/// Set the global subscriber called when libgit2 produces a tracing message. +pub fn trace_set(level: TraceLevel, cb: TracingCb) -> Result<(), Error> { + // Store the callback in the global atomic. + CALLBACK.store(cb as *mut (), Ordering::SeqCst); - unsafe { - raw::git_trace_set(level.raw(), Some(tracing_cb_c)); - } + // git_trace_set returns 0 if there was no error. + let return_code: c_int = unsafe { raw::git_trace_set(level.raw(), Some(tracing_cb_c)) }; - return true; + if return_code != 0 { + Err(Error::last_error(return_code)) + } else { + Ok(()) + } } +/// The tracing callback we pass to libgit2 (C ABI compatible). extern "C" fn tracing_cb_c(level: raw::git_trace_level_t, msg: *const c_char) { - let cb = CALLBACK.load(Ordering::SeqCst); - panic::wrap(|| unsafe { - let cb: TracingCb = std::mem::transmute(cb); - let msg = std::ffi::CStr::from_ptr(msg).to_string_lossy(); - cb(Binding::from_raw(level), msg.as_ref()); + // Load the callback function pointer from the global atomic. + let cb: *mut () = CALLBACK.load(Ordering::SeqCst); + + // Transmute the callback pointer into the function pointer we know it to be. + // + // SAFETY: We only ever set the callback pointer with something cast from a TracingCb + // so transmuting back to a TracingCb is safe. This is notably not an integer-to-pointer + // transmute as described in the mem::transmute documentation and is in-line with the + // example in that documentation for casing between *const () to fn pointers. + let cb: TracingCb = unsafe { std::mem::transmute(cb) }; + + // If libgit2 passes us a message that is null, drop it and do not pass it to the callback. + // This is to avoid ever exposing rust code to a null ref, which would be Undefined Behavior. + if msg.is_null() { + return; + } + + // Convert the message from a *const c_char to a &[u8] and pass it to the callback. + // + // SAFETY: We've just checked that the pointer is not null. The other safety requirements are left to + // libgit2 to enforce -- namely that it gives us a valid, nul-terminated, C string, that that string exists + // entirely in one allocation, that the string will not be mutated once passed to us, and that the nul-terminator is + // within isize::MAX bytes from the given pointers data address. + let msg: &CStr = unsafe { CStr::from_ptr(msg) }; + + // Convert from a CStr to &[u8] to pass to the rust code callback. + let msg: &[u8] = CStr::to_bytes(msg); + + // Do the remaining part of this function in a panic wrapper, to catch any panics it produces. + panic::wrap(|| { + // Convert the raw trace level into a type we can pass to the rust callback fn. + // + // SAFETY: Currently the implementation of this function (above) may panic, but is only marked as unsafe to match + // the trait definition, thus we can consider this call safe. + let level: TraceLevel = unsafe { Binding::from_raw(level) }; + + // Call the user-supplied callback (which may panic). + (cb)(level, msg); }); } + +#[cfg(test)] +mod tests { + use super::TraceLevel; + + // Test that using the above function to set a tracing callback doesn't panic. + #[test] + fn smoke() { + super::trace_set(TraceLevel::Trace, |level, msg| { + dbg!(level, msg); + }) + .expect("libgit2 can set global trace callback"); + } +}