Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ProgressBar trait and progress bar for package writing #525

Merged
merged 3 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 118 additions & 26 deletions crates/rattler_package_streaming/src/write.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,80 @@
//! Functionality for writing conda packages
use std::fs::{self, File};
use std::io::{self, Seek, Write};
use std::io::{self, Read, Seek, Write};
use std::path::{Path, PathBuf};

use itertools::sorted;

use rattler_conda_types::package::PackageMetadata;

/// Receiver for progress notifications.
///
/// Implementors are told the overall total once (or whenever it changes) via
/// [`ProgressBar::set_total`] and are then updated through
/// [`ProgressBar::set_progress`].
pub trait ProgressBar {
    /// Set the total amount of bytes.
    fn set_total(&mut self, total: u64);

    /// Set the current progress together with the message to display.
    fn set_progress(&mut self, progress: u64, message: &str);
}

/// A [`Read`] adapter around an optional [`File`] that forwards the number of
/// bytes read so far to an optional [`ProgressBar`].
///
/// The file can be swapped with `set_file` while the byte counter keeps
/// accumulating, so one instance can track progress across several files.
struct ProgressBarReader {
    /// File that `read` pulls bytes from; `None` until `set_file` is called.
    reader: Option<File>,
    /// Progress sink; when `None`, no notifications are sent.
    progress_bar: Option<Box<dyn ProgressBar>>,
    /// Bytes read since construction or the last `reset_position`.
    progress: u64,
    /// Last value handed to `set_total`.
    total: u64,
    /// Message passed along with every progress update (starts out empty).
    message: String,
}

impl ProgressBarReader {
    /// Creates a reader with no file attached, zeroed counters and an empty message.
    fn new(progress_bar: Option<Box<dyn ProgressBar>>) -> Self {
        ProgressBarReader {
            progress_bar,
            reader: None,
            message: String::new(),
            progress: 0,
            total: 0,
        }
    }

    /// Makes `file` the source for subsequent reads.
    fn set_file(&mut self, file: File) {
        // Any previously attached file is dropped here.
        drop(self.reader.replace(file));
    }

    /// Resets the byte counter to zero and reports that to the progress bar.
    fn reset_position(&mut self) {
        self.progress = 0;
        self.publish_progress();
    }

    /// Stores the new total and forwards it to the progress bar, if any.
    fn set_total(&mut self, total_size: u64) {
        self.total = total_size;
        if let Some(bar) = self.progress_bar.as_mut() {
            bar.set_total(total_size);
        }
    }

    /// Sends the current counter and message to the progress bar, if any.
    fn publish_progress(&mut self) {
        if let Some(bar) = self.progress_bar.as_mut() {
            bar.set_progress(self.progress, &self.message);
        }
    }
}

impl Read for ProgressBarReader {
    /// Reads from the attached file, adds the byte count to the running
    /// progress and notifies the progress bar.
    ///
    /// # Panics
    ///
    /// Panics if no file has been attached with `set_file`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `&File` implements `Read`, so a shared borrow of the file suffices.
        let mut source: &File = self.reader.as_ref().expect("No reader set!");
        let bytes_read = source.read(buf)?;
        self.progress += bytes_read as u64;
        self.publish_progress();
        Ok(bytes_read)
    }
}

/// Splits `paths` into two vectors: those that start with `info/` and those
/// that do not.
///
/// Every path is first made relative by stripping `base_path` from its front.
/// Both returned vectors are sorted (using `PathBuf` ordering) for
/// reproducibility. The first element of the tuple holds the `info/` paths,
/// the second all remaining paths.
///
/// # Panics
///
/// Panics if any entry of `paths` does not start with `base_path`.
fn sort_paths(paths: &[PathBuf], base_path: &Path) -> (Vec<PathBuf>, Vec<PathBuf>) {
    let info = Path::new("info/");
    let (mut info_paths, mut other_paths): (Vec<_>, Vec<_>) = paths
        .iter()
        .map(|p| p.strip_prefix(base_path).unwrap().to_path_buf())
        // `starts_with` compares whole components, so `information/x` is not
        // treated as an `info/` path.
        .partition(|path| path.starts_with(info));

    info_paths.sort();
    other_paths.sort();

    (info_paths, other_paths)
}
Expand Down Expand Up @@ -82,6 +133,13 @@ impl CompressionLevel {
}
}

/// Returns the combined size in bytes of all regular files among `paths`,
/// each resolved against `base_path`.
///
/// The result is used as the progress total. Only regular files have their
/// contents streamed through the progress reader when they are appended to
/// the archive, so directories and symbolic links are not counted (symlinks
/// are inspected without being followed). Otherwise the reported progress
/// could never reach the total. Paths whose metadata cannot be read
/// contribute `0`.
fn total_size(base_path: &Path, paths: &[PathBuf]) -> u64 {
    paths
        .iter()
        .filter_map(|p| base_path.join(p).symlink_metadata().ok())
        .filter(|meta| meta.is_file())
        .map(|meta| meta.len())
        .sum()
}

/// Write the contents of a list of paths to a tar.bz2 package
/// The paths are sorted alphabetically, and paths beginning with `info/` come first.
///
Expand All @@ -107,7 +165,7 @@ impl CompressionLevel {
///
/// let paths = vec![PathBuf::from("info/recipe/meta.yaml"), PathBuf::from("info/recipe/conda_build_config.yaml")];
/// let mut file = File::create("test.tar.bz2").unwrap();
/// write_tar_bz2_package(&mut file, &PathBuf::from("test"), &paths, CompressionLevel::Default, None).unwrap();
/// write_tar_bz2_package(&mut file, &PathBuf::from("test"), &paths, CompressionLevel::Default, None, None).unwrap();
/// ```
///
/// # See also
Expand All @@ -119,17 +177,28 @@ pub fn write_tar_bz2_package<W: Write>(
paths: &[PathBuf],
compression_level: CompressionLevel,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: Option<Box<dyn ProgressBar>>,
) -> Result<(), std::io::Error> {
let mut archive = tar::Builder::new(bzip2::write::BzEncoder::new(
writer,
compression_level.to_bzip2_level()?,
));
archive.follow_symlinks(false);

let total_size = total_size(base_path, paths);
let mut progress_bar_wrapper = ProgressBarReader::new(progress_bar);
progress_bar_wrapper.set_total(total_size);

// sort paths alphabetically, and sort paths beginning with `info/` first
let (info_paths, other_paths) = sort_paths(paths, base_path);
for path in info_paths.chain(other_paths) {
append_path_to_archive(&mut archive, base_path, &path, timestamp)?;
for path in info_paths.iter().chain(other_paths.iter()) {
append_path_to_archive(
&mut archive,
base_path,
path,
timestamp,
&mut progress_bar_wrapper,
)?;
}

archive.into_inner()?.finish()?;
Expand All @@ -141,30 +210,47 @@ pub fn write_tar_bz2_package<W: Write>(
fn write_zst_archive<W: Write>(
writer: W,
base_path: &Path,
paths: impl Iterator<Item = PathBuf>,
paths: &Vec<PathBuf>,
compression_level: CompressionLevel,
num_threads: Option<u32>,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: Option<Box<dyn ProgressBar>>,
) -> Result<(), std::io::Error> {
// Create a temporary tar file
let tar_path = tempfile::Builder::new().tempfile_in(base_path)?;
let mut archive = tar::Builder::new(&tar_path);
archive.follow_symlinks(false);

let total_size = total_size(base_path, paths);
let mut progress_bar_wrapper = ProgressBarReader::new(progress_bar);
progress_bar_wrapper.set_total(total_size);
for path in paths {
append_path_to_archive(&mut archive, base_path, &path, timestamp)?;
append_path_to_archive(
&mut archive,
base_path,
path,
timestamp,
&mut progress_bar_wrapper,
)?;
}
archive.finish()?;

// Compress it as tar.zst
let mut tar_file = File::open(&tar_path)?;
let tar_file = File::open(&tar_path)?;
let compression_level = compression_level.to_zstd_level()?;
let mut zst_encoder = zstd::Encoder::new(writer, compression_level)?;
zst_encoder.multithread(num_threads.unwrap_or_else(|| num_cpus::get() as u32))?;
zst_encoder.set_pledged_src_size(tar_file.metadata().map(|v| v.len()).ok())?;

progress_bar_wrapper.reset_position();
if let Ok(tar_total_size) = tar_file.metadata().map(|v| v.len()) {
zst_encoder.set_pledged_src_size(Some(tar_total_size))?;
progress_bar_wrapper.set_total(tar_total_size);
};
zst_encoder.include_contentsize(true)?;

// Append tar.zst to the archive
io::copy(&mut tar_file, &mut zst_encoder)?;
progress_bar_wrapper.set_file(tar_file);
io::copy(&mut progress_bar_wrapper, &mut zst_encoder)?;
zst_encoder.finish()?;

Ok(())
Expand All @@ -190,6 +276,7 @@ fn write_zst_archive<W: Write>(
///
/// This function will return an error if the writer returns an error, or if the paths are not
/// relative to the base path.
#[allow(clippy::too_many_arguments)]
pub fn write_conda_package<W: Write + Seek>(
writer: W,
base_path: &Path,
Expand All @@ -198,6 +285,7 @@ pub fn write_conda_package<W: Write + Seek>(
compression_num_threads: Option<u32>,
out_name: &str,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: Option<Box<dyn ProgressBar>>,
) -> Result<(), std::io::Error> {
// first create the outer zip archive that uses no compression
let mut outer_archive = zip::ZipWriter::new(writer);
Expand All @@ -217,10 +305,11 @@ pub fn write_conda_package<W: Write + Seek>(
write_zst_archive(
&mut outer_archive,
base_path,
other_paths,
&other_paths,
compression_level,
compression_num_threads,
timestamp,
progress_bar,
)?;

// info paths come last
Expand All @@ -229,10 +318,11 @@ pub fn write_conda_package<W: Write + Seek>(
write_zst_archive(
&mut outer_archive,
base_path,
info_paths,
&info_paths,
compression_level,
compression_num_threads,
timestamp,
None,
)?;

outer_archive.finish()?;
Expand Down Expand Up @@ -271,6 +361,7 @@ fn append_path_to_archive(
base_path: &Path,
path: &Path,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: &mut ProgressBarReader,
) -> Result<(), std::io::Error> {
// create a tar header
let mut header = prepare_header(&base_path.join(path), timestamp)
Expand All @@ -279,8 +370,9 @@ fn append_path_to_archive(
if header.entry_type().is_file() {
let file = fs::File::open(base_path.join(path))
.map_err(|err| trace_file_error(&base_path.join(path), err))?;

archive.append_data(&mut header, path, &file)?;
// wrap the file reader in a progress bar reader
progress_bar.set_file(file);
archive.append_data(&mut header, path, progress_bar)?;
} else if header.entry_type().is_symlink() || header.entry_type().is_hard_link() {
let target = fs::read_link(base_path.join(path))
.map_err(|err| trace_file_error(&base_path.join(path), err))?;
Expand Down
12 changes: 10 additions & 2 deletions crates/rattler_package_streaming/tests/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,15 @@ fn test_rewrite_tar_bz2() {

let writer = File::create(&new_archive).unwrap();
let paths = find_all_package_files(&target_dir);
write_tar_bz2_package(writer, &target_dir, &paths, CompressionLevel::Default, None)
.unwrap();
write_tar_bz2_package(
writer,
&target_dir,
&paths,
CompressionLevel::Default,
None,
None,
)
.unwrap();

// compare the two archives
let mut f1 = File::open(&file_path).unwrap();
Expand Down Expand Up @@ -219,6 +226,7 @@ fn test_rewrite_conda() {
None,
&name,
None,
None,
)
.unwrap();

Expand Down
Loading