diff --git a/crates/rattler_package_streaming/src/write.rs b/crates/rattler_package_streaming/src/write.rs index 9ff2ac0a1..6e5f62b01 100644 --- a/crates/rattler_package_streaming/src/write.rs +++ b/crates/rattler_package_streaming/src/write.rs @@ -1,29 +1,80 @@ //! Functionality for writing conda packages use std::fs::{self, File}; -use std::io::{self, Seek, Write}; +use std::io::{self, Read, Seek, Write}; use std::path::{Path, PathBuf}; -use itertools::sorted; - use rattler_conda_types::package::PackageMetadata; +/// Trait for progress bars +pub trait ProgressBar { + /// Set the current progress and progress message + fn set_progress(&mut self, progress: u64, message: &str); + /// Set the total amount of bytes + fn set_total(&mut self, total: u64); +} + +/// A wrapper for a reader that updates a progress bar +struct ProgressBarReader { + reader: Option, + progress_bar: Option>, + progress: u64, + total: u64, + message: String, +} + +impl ProgressBarReader { + fn new(progress_bar: Option>) -> Self { + Self { + reader: None, + progress_bar, + progress: 0, + total: 0, + message: String::new(), + } + } + + fn set_file(&mut self, file: File) { + self.reader = Some(file); + } + + fn reset_position(&mut self) { + self.progress = 0; + if let Some(progress_bar) = &mut self.progress_bar { + progress_bar.set_progress(0, &self.message); + } + } + + fn set_total(&mut self, total_size: u64) { + self.total = total_size; + if let Some(progress_bar) = &mut self.progress_bar { + progress_bar.set_total(total_size); + } + } +} + +impl Read for ProgressBarReader { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let n = self.reader.as_ref().expect("No reader set!").read(buf)?; + self.progress += n as u64; + if let Some(progress_bar) = &mut self.progress_bar { + progress_bar.set_progress(self.progress, &self.message); + } + Ok(n) + } +} + /// a function that sorts paths into two iterators, one that starts with `info/` and one that does not /// both iterators are sorted alphabetically for reproducibility -fn sort_paths<'a>( - paths: &'a [PathBuf], - base_path: &'a Path, -) -> ( - impl Iterator + 'a, - impl Iterator + 'a, -) { +fn sort_paths<'a>(paths: &'a [PathBuf], base_path: &'a Path) -> (Vec, Vec) { let info = Path::new("info/"); - let (info_paths, other_paths): (Vec<_>, Vec<_>) = paths + let (mut info_paths, mut other_paths): (Vec<_>, Vec<_>) = paths .iter() .map(|p| p.strip_prefix(base_path).unwrap()) - .partition(|&path| path.starts_with(info)); + .map(Path::to_path_buf) + .partition(|path| path.starts_with(info)); - let info_paths = sorted(info_paths.into_iter().map(std::path::Path::to_path_buf)); - let other_paths = sorted(other_paths.into_iter().map(std::path::Path::to_path_buf)); + info_paths.sort(); + other_paths.sort(); (info_paths, other_paths) } @@ -82,6 +133,13 @@ impl CompressionLevel { } } +fn total_size(base_path: &Path, paths: &[PathBuf]) -> u64 { + paths + .iter() + .map(|p| base_path.join(p).metadata().map(|m| m.len()).unwrap_or(0)) + .sum() +} + /// Write the contents of a list of paths to a tar.bz2 package /// The paths are sorted alphabetically, and paths beginning with `info/` come first. /// @@ -107,7 +165,7 @@ impl CompressionLevel { /// /// let paths = vec![PathBuf::from("info/recipe/meta.yaml"), PathBuf::from("info/recipe/conda_build_config.yaml")]; /// let mut file = File::create("test.tar.bz2").unwrap(); -/// write_tar_bz2_package(&mut file, &PathBuf::from("test"), &paths, CompressionLevel::Default, None).unwrap(); +/// write_tar_bz2_package(&mut file, &PathBuf::from("test"), &paths, CompressionLevel::Default, None, None).unwrap(); /// ``` /// /// # See also @@ -119,6 +177,7 @@ pub fn write_tar_bz2_package( paths: &[PathBuf], compression_level: CompressionLevel, timestamp: Option<&chrono::DateTime>, + progress_bar: Option>, ) -> Result<(), std::io::Error> { let mut archive = tar::Builder::new(bzip2::write::BzEncoder::new( writer, @@ -126,10 +185,20 @@ pub fn write_tar_bz2_package( )); archive.follow_symlinks(false); + let total_size = total_size(base_path, paths); + let mut progress_bar_wrapper = ProgressBarReader::new(progress_bar); + progress_bar_wrapper.set_total(total_size); + // sort paths alphabetically, and sort paths beginning with `info/` first let (info_paths, other_paths) = sort_paths(paths, base_path); - for path in info_paths.chain(other_paths) { - append_path_to_archive(&mut archive, base_path, &path, timestamp)?; + for path in info_paths.iter().chain(other_paths.iter()) { + append_path_to_archive( + &mut archive, + base_path, + path, + timestamp, + &mut progress_bar_wrapper, + )?; } archive.into_inner()?.finish()?; @@ -141,30 +210,47 @@ pub fn write_tar_bz2_package( fn write_zst_archive( writer: W, base_path: &Path, - paths: impl Iterator, + paths: &Vec, compression_level: CompressionLevel, num_threads: Option, timestamp: Option<&chrono::DateTime>, + progress_bar: Option>, ) -> Result<(), std::io::Error> { // Create a temporary tar file let tar_path = tempfile::Builder::new().tempfile_in(base_path)?; let mut archive = tar::Builder::new(&tar_path); archive.follow_symlinks(false); + + let total_size = total_size(base_path, paths); + let mut progress_bar_wrapper = ProgressBarReader::new(progress_bar); + progress_bar_wrapper.set_total(total_size); for path in paths { - append_path_to_archive(&mut archive, base_path, &path, timestamp)?; + append_path_to_archive( + &mut archive, + base_path, + path, + timestamp, + &mut progress_bar_wrapper, + )?; } archive.finish()?; // Compress it as tar.zst - let mut tar_file = File::open(&tar_path)?; + let tar_file = File::open(&tar_path)?; let compression_level = compression_level.to_zstd_level()?; let mut zst_encoder = zstd::Encoder::new(writer, compression_level)?; zst_encoder.multithread(num_threads.unwrap_or_else(|| num_cpus::get() as u32))?; - zst_encoder.set_pledged_src_size(tar_file.metadata().map(|v| v.len()).ok())?; + + progress_bar_wrapper.reset_position(); + if let Ok(tar_total_size) = tar_file.metadata().map(|v| v.len()) { + zst_encoder.set_pledged_src_size(Some(tar_total_size))?; + progress_bar_wrapper.set_total(tar_total_size); + }; zst_encoder.include_contentsize(true)?; // Append tar.zst to the archive - io::copy(&mut tar_file, &mut zst_encoder)?; + progress_bar_wrapper.set_file(tar_file); + io::copy(&mut progress_bar_wrapper, &mut zst_encoder)?; zst_encoder.finish()?; Ok(()) @@ -190,6 +276,7 @@ fn write_zst_archive( /// /// This function will return an error if the writer returns an error, or if the paths are not /// relative to the base path. +#[allow(clippy::too_many_arguments)] pub fn write_conda_package( writer: W, base_path: &Path, @@ -198,6 +285,7 @@ pub fn write_conda_package( compression_num_threads: Option, out_name: &str, timestamp: Option<&chrono::DateTime>, + progress_bar: Option>, ) -> Result<(), std::io::Error> { // first create the outer zip archive that uses no compression let mut outer_archive = zip::ZipWriter::new(writer); @@ -217,10 +305,11 @@ pub fn write_conda_package( write_zst_archive( &mut outer_archive, base_path, - other_paths, + &other_paths, compression_level, compression_num_threads, timestamp, + progress_bar, )?; // info paths come last @@ -229,10 +318,11 @@ pub fn write_conda_package( write_zst_archive( &mut outer_archive, base_path, - info_paths, + &info_paths, compression_level, compression_num_threads, timestamp, + None, )?; outer_archive.finish()?; @@ -271,6 +361,7 @@ fn append_path_to_archive( base_path: &Path, path: &Path, timestamp: Option<&chrono::DateTime>, + progress_bar: &mut ProgressBarReader, ) -> Result<(), std::io::Error> { // create a tar header let mut header = prepare_header(&base_path.join(path), timestamp) @@ -279,8 +370,9 @@ fn append_path_to_archive( if header.entry_type().is_file() { let file = fs::File::open(base_path.join(path)) .map_err(|err| trace_file_error(&base_path.join(path), err))?; - - archive.append_data(&mut header, path, &file)?; + // wrap the file reader in a progress bar reader + progress_bar.set_file(file); + archive.append_data(&mut header, path, progress_bar)?; } else if header.entry_type().is_symlink() || header.entry_type().is_hard_link() { let target = fs::read_link(base_path.join(path)) .map_err(|err| trace_file_error(&base_path.join(path), err))?; diff --git a/crates/rattler_package_streaming/tests/write.rs b/crates/rattler_package_streaming/tests/write.rs index 56ba970e1..b1b31860c 100644 --- a/crates/rattler_package_streaming/tests/write.rs +++ b/crates/rattler_package_streaming/tests/write.rs @@ -178,8 +178,15 @@ fn test_rewrite_tar_bz2() { let writer = File::create(&new_archive).unwrap(); let paths = find_all_package_files(&target_dir); - write_tar_bz2_package(writer, &target_dir, &paths, CompressionLevel::Default, None) - .unwrap(); + write_tar_bz2_package( + writer, + &target_dir, + &paths, + CompressionLevel::Default, + None, + None, + ) + .unwrap(); // compare the two archives let mut f1 = File::open(&file_path).unwrap(); @@ -219,6 +226,7 @@ fn test_rewrite_conda() { None, &name, None, + None, ) .unwrap();