diff --git a/zstd-safe/src/lib.rs b/zstd-safe/src/lib.rs index 2a8d7c1b..c1104fa0 100644 --- a/zstd-safe/src/lib.rs +++ b/zstd-safe/src/lib.rs @@ -748,6 +748,38 @@ impl<'a> CCtx<'a> { pub fn out_size() -> usize { unsafe { zstd_sys::ZSTD_CStreamOutSize() } } + + /// Use a shared thread pool for this context. + /// + /// Thread pool must outlive the context. + #[cfg(all(feature = "experimental", feature = "zstdmt"))] + #[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) + )] + pub fn ref_thread_pool<'b>(&mut self, pool: &'b ThreadPool) -> SafeResult + where + 'b: 'a, + { + parse_code(unsafe { + zstd_sys::ZSTD_CCtx_refThreadPool(self.0.as_ptr(), pool.0.as_ptr()) + }) + } + + /// Return to using a private thread pool for this context. + #[cfg(all(feature = "experimental", feature = "zstdmt"))] + #[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) + )] + pub fn disable_thread_pool(&mut self) -> SafeResult { + parse_code(unsafe { + zstd_sys::ZSTD_CCtx_refThreadPool( + self.0.as_ptr(), + core::ptr::null_mut(), + ) + }) + } } impl<'a> Drop for CCtx<'a> { @@ -1355,6 +1387,64 @@ impl<'a> Drop for DDict<'a> { unsafe impl<'a> Send for DDict<'a> {} unsafe impl<'a> Sync for DDict<'a> {} +/// A shared thread pool for one or more compression contexts +#[cfg(all(feature = "experimental", feature = "zstdmt"))] +#[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) +)] +pub struct ThreadPool(NonNull); + +#[cfg(all(feature = "experimental", feature = "zstdmt"))] +#[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) +)] +impl ThreadPool { + /// Create a thread pool with the specified number of threads. + /// + /// # Panics + /// + /// If creating the thread pool failed. + pub fn new(num_threads: usize) -> Self { + Self::try_new(num_threads) + .expect("zstd returned null pointer when creating thread pool") + } + + /// Create a thread pool with the specified number of threads. + pub fn try_new(num_threads: usize) -> Option { + Some(Self(NonNull::new(unsafe { + zstd_sys::ZSTD_createThreadPool(num_threads) + })?)) + } +} + +#[cfg(all(feature = "experimental", feature = "zstdmt"))] +#[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) +)] +impl Drop for ThreadPool { + fn drop(&mut self) { + unsafe { + zstd_sys::ZSTD_freeThreadPool(self.0.as_ptr()); + } + } +} + +#[cfg(all(feature = "experimental", feature = "zstdmt"))] +#[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) +)] +unsafe impl Send for ThreadPool {} +#[cfg(all(feature = "experimental", feature = "zstdmt"))] +#[cfg_attr( + feature = "doc-cfg", + doc(cfg(all(feature = "experimental", feature = "zstdmt"))) +)] +unsafe impl Sync for ThreadPool {} + /// Wraps the `ZSTD_decompress_usingDDict()` function. pub fn decompress_using_ddict( dctx: &mut DCtx<'_>,