Skip to content

Commit

Permalink
implement progress bar via trait and reader-wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfv committed Feb 12, 2024
1 parent 38ae450 commit 1a0d24e
Showing 1 changed file with 117 additions and 26 deletions.
143 changes: 117 additions & 26 deletions crates/rattler_package_streaming/src/write.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,80 @@
//! Functionality for writing conda packages
use std::fs::{self, File};
use std::io::{self, Seek, Write};
use std::io::{self, Read, Seek, Write};
use std::path::{Path, PathBuf};

use itertools::sorted;

use rattler_conda_types::package::PackageMetadata;

/// Trait for progress bars
///
/// Implementors receive notifications from the package writers in this
/// module. The internal `ProgressBarReader` calls [`ProgressBar::set_total`]
/// with the number of bytes it expects to read and
/// [`ProgressBar::set_progress`] after every read with the running byte count.
pub trait ProgressBar {
/// Set the current progress and progress message
///
/// `progress` is the number of bytes read so far; `message` is a
/// human-readable label to show alongside it.
fn set_progress(&mut self, progress: u64, message: &str);
/// Set the total amount of bytes
///
/// May be called more than once: the zst writer first announces the summed
/// file sizes and later the size of the temporary tar file it compresses.
fn set_total(&mut self, total: u64);
}

/// A wrapper for a reader that updates a progress bar
///
/// One instance is reused for many files: `set_file` swaps the underlying
/// file while the byte counter keeps accumulating, so the progress bar
/// reflects the whole operation rather than a single file.
struct ProgressBarReader {
/// File currently being read; `None` until `set_file` has been called.
reader: Option<File>,
/// Optional sink for progress updates; when `None` reads are only counted.
progress_bar: Option<Box<dyn ProgressBar>>,
/// Number of bytes read since construction or the last `reset_position`.
progress: u64,
/// Most recent total passed to `set_total`.
total: u64,
/// Message forwarded with every `set_progress` call (initialised empty).
message: String,
}

impl ProgressBarReader {
    /// Creates a wrapper around an optional progress bar.
    ///
    /// No file is attached yet and both the byte counter and the total start
    /// at zero; the message starts out empty.
    fn new(progress_bar: Option<Box<dyn ProgressBar>>) -> Self {
        Self {
            progress_bar,
            reader: None,
            message: String::new(),
            progress: 0,
            total: 0,
        }
    }

    /// Attaches `file` as the source for subsequent reads, replacing (and
    /// thereby dropping) any previously attached file.
    fn set_file(&mut self, file: File) {
        self.reader = Some(file);
    }

    /// Restarts the byte counter at zero and tells the progress bar, if any.
    fn reset_position(&mut self) {
        self.progress = 0;
        if let Some(bar) = self.progress_bar.as_mut() {
            bar.set_progress(0, &self.message);
        }
    }

    /// Records the expected number of bytes and forwards it to the progress
    /// bar, if any.
    fn set_total(&mut self, total_size: u64) {
        self.total = total_size;
        if let Some(bar) = self.progress_bar.as_mut() {
            bar.set_total(total_size);
        }
    }
}

impl Read for ProgressBarReader {
    /// Reads from the attached file and reports the new cumulative byte count
    /// to the progress bar, if one is present.
    ///
    /// # Panics
    ///
    /// Panics if no file has been attached with `set_file`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let file = self.reader.as_mut().expect("No reader set!");
        let bytes_read = file.read(buf)?;

        self.progress += bytes_read as u64;
        if let Some(bar) = self.progress_bar.as_mut() {
            bar.set_progress(self.progress, &self.message);
        }

        Ok(bytes_read)
    }
}

/// Splits `paths` into the entries located under `info/` and all remaining
/// entries.
///
/// Every path is first made relative to `base_path`. Both returned vectors
/// (`info` paths first, other paths second) are sorted so that the resulting
/// archives are reproducible.
///
/// # Panics
///
/// Panics if an entry of `paths` is not located inside `base_path`.
fn sort_paths(paths: &[PathBuf], base_path: &Path) -> (Vec<PathBuf>, Vec<PathBuf>) {
    let info = Path::new("info/");
    let (mut info_paths, mut other_paths): (Vec<PathBuf>, Vec<PathBuf>) = paths
        .iter()
        .map(|p| {
            p.strip_prefix(base_path)
                .expect("every path must be located inside `base_path`")
                .to_path_buf()
        })
        // `Path::starts_with` compares whole components, so `information.txt`
        // is not mistaken for a member of `info/`.
        .partition(|path| path.starts_with(info));

    info_paths.sort();
    other_paths.sort();

    (info_paths, other_paths)
}
Expand Down Expand Up @@ -82,6 +133,13 @@ impl CompressionLevel {
}
}

/// Sums the sizes in bytes of the regular files among `paths`.
///
/// Each entry of `paths` is resolved relative to `base_path`. Symlinks are
/// inspected without being followed and, like directories, contribute
/// nothing: the archives are built with `follow_symlinks(false)` and only
/// regular files are streamed through the progress reader, so counting
/// anything else would announce a total that is never reached. Entries whose
/// metadata cannot be read are skipped.
fn total_size(base_path: &Path, paths: &[PathBuf]) -> u64 {
    paths
        .iter()
        .filter_map(|p| base_path.join(p).symlink_metadata().ok())
        .filter(|meta| meta.is_file())
        .map(|meta| meta.len())
        .sum()
}

/// Write the contents of a list of paths to a tar.bz2 package
/// The paths are sorted alphabetically, and paths beginning with `info/` come first.
///
Expand Down Expand Up @@ -119,17 +177,28 @@ pub fn write_tar_bz2_package<W: Write>(
paths: &[PathBuf],
compression_level: CompressionLevel,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: Option<Box<dyn ProgressBar>>,
) -> Result<(), std::io::Error> {
let mut archive = tar::Builder::new(bzip2::write::BzEncoder::new(
writer,
compression_level.to_bzip2_level()?,
));
archive.follow_symlinks(false);

let total_size = total_size(base_path, &paths);
let mut progress_bar_wrapper = ProgressBarReader::new(progress_bar);
progress_bar_wrapper.set_total(total_size);

// sort paths alphabetically, and sort paths beginning with `info/` first
let (info_paths, other_paths) = sort_paths(paths, base_path);
for path in info_paths.chain(other_paths) {
append_path_to_archive(&mut archive, base_path, &path, timestamp)?;
for path in info_paths.iter().chain(other_paths.iter()) {
append_path_to_archive(
&mut archive,
base_path,
&path,
timestamp,
&mut progress_bar_wrapper,
)?;
}

archive.into_inner()?.finish()?;
Expand All @@ -138,33 +207,50 @@ pub fn write_tar_bz2_package<W: Write>(
}

/// Write the contents of a list of paths to a tar zst archive
fn write_zst_archive<W: Write>(
fn write_zst_archive<'a, W: Write>(
writer: W,
base_path: &Path,
paths: impl Iterator<Item = PathBuf>,
paths: &Vec<PathBuf>,
compression_level: CompressionLevel,
num_threads: Option<u32>,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: Option<Box<dyn ProgressBar>>,
) -> Result<(), std::io::Error> {
// Create a temporary tar file
let tar_path = tempfile::Builder::new().tempfile_in(base_path)?;
let mut archive = tar::Builder::new(&tar_path);
archive.follow_symlinks(false);

let total_size = total_size(base_path, &paths);
let mut progress_bar_wrapper = ProgressBarReader::new(progress_bar);
progress_bar_wrapper.set_total(total_size);
for path in paths {
append_path_to_archive(&mut archive, base_path, &path, timestamp)?;
append_path_to_archive(
&mut archive,
base_path,
&path,
timestamp,
&mut progress_bar_wrapper,
)?;
}
archive.finish()?;

// Compress it as tar.zst
let mut tar_file = File::open(&tar_path)?;
let tar_file = File::open(&tar_path)?;
let compression_level = compression_level.to_zstd_level()?;
let mut zst_encoder = zstd::Encoder::new(writer, compression_level)?;
zst_encoder.multithread(num_threads.unwrap_or_else(|| num_cpus::get() as u32))?;
zst_encoder.set_pledged_src_size(tar_file.metadata().map(|v| v.len()).ok())?;

progress_bar_wrapper.reset_position();
if let Some(tar_total_size) = tar_file.metadata().map(|v| v.len()).ok() {
zst_encoder.set_pledged_src_size(Some(tar_total_size))?;
progress_bar_wrapper.set_total(tar_total_size);
};
zst_encoder.include_contentsize(true)?;

// Append tar.zst to the archive
io::copy(&mut tar_file, &mut zst_encoder)?;
progress_bar_wrapper.set_file(tar_file);
io::copy(&mut progress_bar_wrapper, &mut zst_encoder)?;
zst_encoder.finish()?;

Ok(())
Expand Down Expand Up @@ -198,6 +284,7 @@ pub fn write_conda_package<W: Write + Seek>(
compression_num_threads: Option<u32>,
out_name: &str,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: Option<Box<dyn ProgressBar>>,
) -> Result<(), std::io::Error> {
// first create the outer zip archive that uses no compression
let mut outer_archive = zip::ZipWriter::new(writer);
Expand All @@ -217,10 +304,11 @@ pub fn write_conda_package<W: Write + Seek>(
write_zst_archive(
&mut outer_archive,
base_path,
other_paths,
&other_paths,
compression_level,
compression_num_threads,
timestamp,
progress_bar,
)?;

// info paths come last
Expand All @@ -229,10 +317,11 @@ pub fn write_conda_package<W: Write + Seek>(
write_zst_archive(
&mut outer_archive,
base_path,
info_paths,
&info_paths,
compression_level,
compression_num_threads,
timestamp,
None,
)?;

outer_archive.finish()?;
Expand Down Expand Up @@ -271,6 +360,7 @@ fn append_path_to_archive(
base_path: &Path,
path: &Path,
timestamp: Option<&chrono::DateTime<chrono::Utc>>,
progress_bar: &mut ProgressBarReader,
) -> Result<(), std::io::Error> {
// create a tar header
let mut header = prepare_header(&base_path.join(path), timestamp)
Expand All @@ -279,8 +369,9 @@ fn append_path_to_archive(
if header.entry_type().is_file() {
let file = fs::File::open(base_path.join(path))
.map_err(|err| trace_file_error(&base_path.join(path), err))?;

archive.append_data(&mut header, path, &file)?;
// wrap the file reader in a progress bar reader
progress_bar.set_file(file);
archive.append_data(&mut header, path, progress_bar)?;
} else if header.entry_type().is_symlink() || header.entry_type().is_hard_link() {
let target = fs::read_link(base_path.join(path))
.map_err(|err| trace_file_error(&base_path.join(path), err))?;
Expand Down

0 comments on commit 1a0d24e

Please sign in to comment.