Commit

feat: compute hashes while extracting
baszalmstra committed Apr 19, 2023
1 parent 301dd72 commit e16af90
Showing 12 changed files with 369 additions and 130 deletions.
5 changes: 4 additions & 1 deletion crates/rattler/src/package_cache.rs
@@ -184,7 +184,9 @@ impl PackageCache {
) -> Result<PathBuf, PackageCacheError> {
self.get_or_fetch(pkg, move |destination| async move {
tracing::debug!("downloading {} to {}", &url, destination.display());
rattler_package_streaming::reqwest::tokio::extract(client, url, &destination).await
rattler_package_streaming::reqwest::tokio::extract(client, url, &destination)
.await
.map(|_| ())
})
.await
}
@@ -268,6 +270,7 @@ mod test {
move |destination| async move {
rattler_package_streaming::tokio::fs::extract(&tar_archive_path, &destination)
.await
.map(|_| ())
},
)
.await
59 changes: 59 additions & 0 deletions crates/rattler_digest/src/lib.rs
@@ -47,6 +47,7 @@ pub mod serde;
pub use digest;

use digest::{Digest, Output};
use std::io::Read;
use std::{fs::File, io::Write, path::Path};

pub use md5::Md5;
@@ -129,9 +130,49 @@ impl<W: Write, D: Digest> Write for HashingWriter<W, D> {
}
}

/// A simple object that provides a [`Read`] implementation that also immediately hashes the bytes
/// read from it. Call [`HashingReader::finalize`] to retrieve both the original `impl Read`
/// object as well as the hash.
///
/// If the `tokio` feature is enabled this object also implements [`::tokio::io::AsyncRead`] which
/// allows you to use it in an async context as well.
pub struct HashingReader<R, D: Digest> {
reader: R,
hasher: D,
}

impl<R, D: Digest + Default> HashingReader<R, D> {
/// Constructs a new instance from a reader and a new (empty) hasher.
pub fn new(reader: R) -> Self {
Self {
reader,
hasher: Default::default(),
}
}
}

impl<R, D: Digest> HashingReader<R, D> {
/// Consumes this instance and returns the original reader and the hash of all bytes read from
/// this instance.
pub fn finalize(self) -> (R, Output<D>) {
(self.reader, self.hasher.finalize())
}
}

impl<R: Read, D: Digest> Read for HashingReader<R, D> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
let bytes_read = self.reader.read(buf)?;
self.hasher.update(&buf[..bytes_read]);
Ok(bytes_read)
}
}

#[cfg(test)]
mod test {
use super::HashingReader;
use rstest::rstest;
use sha2::Sha256;
use std::io::Read;

#[rstest]
#[case(
@@ -153,4 +194,22 @@ mod test {

assert_eq!(format!("{hash:x}"), expected_hash)
}

#[rstest]
#[case(
"1234567890",
"c775e7b757ede630cd0aa1113bd102661ab38829ca52a6422ab782862f268646"
)]
#[case(
"Hello, world!",
"315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3"
)]
fn test_hashing_reader_sha256(#[case] input: &str, #[case] expected_hash: &str) {
let mut cursor = HashingReader::<_, Sha256>::new(std::io::Cursor::new(input));
let mut cursor_string = String::new();
cursor.read_to_string(&mut cursor_string).unwrap();
assert_eq!(&cursor_string, input);
let (_, hash) = cursor.finalize();
assert_eq!(format!("{hash:x}"), expected_hash)
}
}
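
Outside the tests, the new reader can wrap any `Read` implementor. A minimal sketch (not part of this commit), assuming `rattler_digest` and its `Sha256` re-export are in scope and using a hypothetical input file:

use rattler_digest::{HashingReader, Sha256};
use std::io::Read;

fn main() -> std::io::Result<()> {
    // Wrap any `Read` implementor; bytes are hashed as they pass through `read`.
    let file = std::fs::File::open("Cargo.toml")?; // hypothetical input file
    let mut reader = HashingReader::<_, Sha256>::new(file);

    // Drain the reader; the digest covers exactly the bytes that were read.
    let mut contents = Vec::new();
    reader.read_to_end(&mut contents)?;

    // Recover the inner reader and the digest.
    let (_file, digest) = reader.finalize();
    println!("sha256 = {digest:x}");
    Ok(())
}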
28 changes: 27 additions & 1 deletion crates/rattler_digest/src/tokio.rs
@@ -1,11 +1,12 @@
use super::HashingWriter;
use crate::HashingReader;
use digest::Digest;
use std::{
io::Error,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::AsyncWrite;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

impl<W: AsyncWrite + Unpin, D: Digest> AsyncWrite for HashingWriter<W, D> {
fn poll_write(
@@ -40,3 +41,28 @@ impl<W: AsyncWrite + Unpin, D: Digest> AsyncWrite for HashingWriter<W, D> {
writer.poll_flush(cx)
}
}

impl<R: AsyncRead + Unpin, D: Digest> AsyncRead for HashingReader<R, D> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let previously_filled = buf.filled().len();

// pin-project the reader
let (reader, hasher) = unsafe {
let this = self.get_unchecked_mut();
(Pin::new_unchecked(&mut this.reader), &mut this.hasher)
};

match reader.poll_read(cx, buf) {
Poll::Ready(Ok(result)) => {
let filled_part = buf.filled();
hasher.update(&filled_part[previously_filled..]);
Poll::Ready(Ok(result))
}
other => other,
}
}
}
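
With the `tokio` feature enabled, the same wrapper works on async readers. A rough sketch (not part of this commit), assuming tokio's `fs`, `rt`, and `macros` features and a hypothetical input file:

use rattler_digest::{HashingReader, Sha256};
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Wrap an async reader; the digest is updated inside `poll_read`.
    let file = tokio::fs::File::open("Cargo.toml").await?; // hypothetical input file
    let mut reader = HashingReader::<_, Sha256>::new(file);

    // Drain the stream so the digest covers the whole file.
    let mut contents = Vec::new();
    reader.read_to_end(&mut contents).await?;

    let (_file, digest) = reader.finalize();
    println!("sha256 = {digest:x}");
    Ok(())
}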
21 changes: 12 additions & 9 deletions crates/rattler_package_streaming/Cargo.toml
@@ -11,20 +11,21 @@ license.workspace = true
readme.workspace = true

[dependencies]
thiserror = "1.0.37"
tar = { version = "0.4.38" }
bzip2 = { version = "0.4" }
zip = { version = "0.6.3" }
zstd = "0.12.1"
reqwest = { version = "0.11.13", optional = true }
tokio = { version = "1", optional = true }
tokio-util = { version = "0.7", optional = true }
chrono = "0.4.24"
futures-util = { version = "0.3.25", optional = true }
rattler_conda_types = { version = "0.2.0", path = "../rattler_conda_types" }
itertools = "0.10.5"
rattler_conda_types = { version = "0.2.0", path = "../rattler_conda_types" }
rattler_digest = { version = "0.2.0", path = "../rattler_digest" }
reqwest = { version = "0.11.13", optional = true }
serde_json = "1.0.94"
tar = { version = "0.4.38" }
thiserror = "1.0.37"
tokio = { version = "1", optional = true }
tokio-util = { version = "0.7", optional = true }
url = "2.3.1"
chrono = "0.4.24"
zip = { version = "0.6.3" }
zstd = "0.12.1"

[features]
tokio = ["dep:tokio", "bzip2/tokio", "tokio/fs", "tokio-util/io", "tokio-util/io-util", "reqwest?/stream", "futures-util"]
@@ -34,3 +35,5 @@ reqwest = ["reqwest/blocking"]
tempfile = "3.4.0"
tokio = { version = "1", features = ["rt", "macros"] }
walkdir = "2.3.2"
rstest = "0.17.0"
rstest_reuse = "0.5.0"
8 changes: 4 additions & 4 deletions crates/rattler_package_streaming/src/fs.rs
@@ -1,6 +1,6 @@
//! Functions for extracting or streaming a Conda package from a file on disk.
use crate::ExtractError;
use crate::{ExtractError, ExtractResult};
use rattler_conda_types::package::ArchiveType;
use std::fs::File;
use std::path::Path;
@@ -15,7 +15,7 @@ use std::path::Path;
/// Path::new("/tmp"))
/// .unwrap();
/// ```
pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<(), ExtractError> {
pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<ExtractResult, ExtractError> {
let file = File::open(archive)?;
crate::read::extract_tar_bz2(file, destination)
}
@@ -30,7 +30,7 @@ pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<(), Extract
/// Path::new("/tmp"))
/// .unwrap();
/// ```
pub fn extract_conda(archive: &Path, destination: &Path) -> Result<(), ExtractError> {
pub fn extract_conda(archive: &Path, destination: &Path) -> Result<ExtractResult, ExtractError> {
let file = File::open(archive)?;
crate::read::extract_conda(file, destination)
}
@@ -46,7 +46,7 @@ pub fn extract_conda(archive: &Path, destination: &Path) -> Result<(), ExtractEr
/// Path::new("/tmp"))
/// .unwrap();
/// ```
pub fn extract(archive: &Path, destination: &Path) -> Result<(), ExtractError> {
pub fn extract(archive: &Path, destination: &Path) -> Result<ExtractResult, ExtractError> {
match ArchiveType::try_from(archive).ok_or(ExtractError::UnsupportedArchiveType)? {
ArchiveType::TarBz2 => extract_tar_bz2(archive, destination),
ArchiveType::Conda => extract_conda(archive, destination),
12 changes: 12 additions & 0 deletions crates/rattler_package_streaming/src/lib.rs
@@ -2,6 +2,8 @@

//! This crate provides the ability to extract a Conda package archive or specific parts of it.
use rattler_digest::{Md5Hash, Sha256Hash};

pub mod read;
pub mod seek;

@@ -42,3 +44,13 @@ pub enum ExtractError {
#[error("the task was cancelled")]
Cancelled,
}

/// Result struct returned by extraction functions.
#[derive(Debug)]
pub struct ExtractResult {
/// The SHA256 hash of the extracted archive.
pub sha256: Sha256Hash,

/// The MD5 hash of the extracted archive.
pub md5: Md5Hash,
}
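
For callers, the extra return value is consumed roughly like this. A sketch (not part of this diff), using the crate's `fs::extract` with hypothetical archive and destination paths:

use rattler_package_streaming::fs::extract;
use std::path::Path;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical archive and destination paths.
    let result = extract(
        Path::new("my-package-1.0-0.tar.bz2"),
        Path::new("/tmp/my-package"),
    )?;

    // Both hashes are computed over the raw archive bytes read during extraction.
    println!("sha256: {:x}", result.sha256);
    println!("md5:    {:x}", result.md5);
    Ok(())
}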
52 changes: 43 additions & 9 deletions crates/rattler_package_streaming/src/read.rs
@@ -1,9 +1,8 @@
//! Functions that enable extracting or streaming a Conda package for objects that implement the
//! [`std::io::Read`] trait.
use super::ExtractError;
use std::ffi::OsStr;
use std::{io::Read, path::Path};
use super::{ExtractError, ExtractResult};
use std::{ffi::OsStr, io::Read, path::Path};
use zip::read::read_zipfile_from_stream;

/// Returns the `.tar.bz2` as a decompressed `tar::Archive`. The `tar::Archive` can be used to
@@ -21,19 +20,41 @@ pub(crate) fn stream_tar_zst(
}

/// Extracts the contents of a `.tar.bz2` package archive.
pub fn extract_tar_bz2(reader: impl Read, destination: &Path) -> Result<(), ExtractError> {
pub fn extract_tar_bz2(
reader: impl Read,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;
stream_tar_bz2(reader).unpack(destination)?;
Ok(())

// Wrap the reader in additional readers that will compute the hashes of the file while it is
// being read.
let sha256_reader = rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader);
let mut md5_reader =
rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(sha256_reader);

// Unpack the archive
stream_tar_bz2(&mut md5_reader).unpack(destination)?;

// Get the hashes
let (sha256_reader, md5) = md5_reader.finalize();
let (_, sha256) = sha256_reader.finalize();

Ok(ExtractResult { sha256, md5 })
}

/// Extracts the contents of a `.conda` package archive.
pub fn extract_conda(mut reader: impl Read, destination: &Path) -> Result<(), ExtractError> {
pub fn extract_conda(reader: impl Read, destination: &Path) -> Result<ExtractResult, ExtractError> {
// Construct the destination path if it doesn't exist yet
std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?;

// Wrap the reader in additional readers that will compute the hashes of the file while it is
// being read.
let sha256_reader = rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader);
let mut md5_reader =
rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(sha256_reader);

// Iterate over all entries in the zip-file and extract them one-by-one
while let Some(file) = read_zipfile_from_stream(&mut reader)? {
while let Some(file) = read_zipfile_from_stream(&mut md5_reader)? {
if file
.mangled_name()
.file_name()
@@ -44,5 +65,18 @@ pub fn extract_conda(mut reader: impl Read, destination: &Path) -> Result<(), Ex
}
}

Ok(())
// Read the rest of the stream so the hashes are computed over the entire file.
let mut buf = [0; 1 << 14];
loop {
let bytes_read = md5_reader.read(&mut buf)?;
if bytes_read == 0 {
break;
}
}

// Get the hashes
let (sha256_reader, md5) = md5_reader.finalize();
let (_, sha256) = sha256_reader.finalize();

Ok(ExtractResult { sha256, md5 })
}
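
The nesting used in both functions — SHA-256 closest to the source, MD5 wrapping it — can be reproduced on its own. A minimal sketch (not part of this commit) with a hypothetical input file:

use rattler_digest::{HashingReader, Md5, Sha256};

fn main() -> std::io::Result<()> {
    let file = std::fs::File::open("archive.tar.bz2")?; // hypothetical input file

    // SHA-256 sits closest to the source; MD5 wraps it, so every read updates both digests.
    let sha256_reader = HashingReader::<_, Sha256>::new(file);
    let mut md5_reader = HashingReader::<_, Md5>::new(sha256_reader);

    // Drain the outer reader so both digests cover the entire file.
    std::io::copy(&mut md5_reader, &mut std::io::sink())?;

    // Unwrap in the reverse order of wrapping.
    let (sha256_reader, md5) = md5_reader.finalize();
    let (_file, sha256) = sha256_reader.finalize();
    println!("sha256 = {sha256:x}, md5 = {md5:x}");
    Ok(())
}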
12 changes: 8 additions & 4 deletions crates/rattler_package_streaming/src/reqwest/mod.rs
@@ -3,7 +3,7 @@
#[cfg(feature = "tokio")]
pub mod tokio;

use crate::ExtractError;
use crate::{ExtractError, ExtractResult};
use rattler_conda_types::package::ArchiveType;
use reqwest::blocking::{Client, Response};
use reqwest::IntoUrl;
@@ -25,7 +25,7 @@ pub fn extract_tar_bz2(
client: Client,
url: impl IntoUrl,
destination: &Path,
) -> Result<(), ExtractError> {
) -> Result<ExtractResult, ExtractError> {
// Send the request for the file
let response = client
.get(url)
@@ -53,7 +53,7 @@ pub fn extract_conda(
client: Client,
url: impl IntoUrl,
destination: &Path,
) -> Result<(), ExtractError> {
) -> Result<ExtractResult, ExtractError> {
// Send the request for the file
let response = client
.get(url)
@@ -78,7 +78,11 @@ pub fn extract_conda(
/// Path::new("/tmp"))
/// .unwrap();
/// ```
pub fn extract(client: Client, url: impl IntoUrl, destination: &Path) -> Result<(), ExtractError> {
pub fn extract(
client: Client,
url: impl IntoUrl,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
let url = url
.into_url()
.map_err(reqwest::Error::from)
12 changes: 8 additions & 4 deletions crates/rattler_package_streaming/src/reqwest/tokio.rs
@@ -1,7 +1,7 @@
//! Functionality to stream and extract packages directly from a [`reqwest::Url`] within a [`tokio`]
//! async context.
use crate::ExtractError;
use crate::{ExtractError, ExtractResult};
use futures_util::stream::TryStreamExt;
use rattler_conda_types::package::ArchiveType;
use reqwest::{Client, Response};
@@ -57,7 +57,7 @@ pub async fn extract_tar_bz2(
client: Client,
url: Url,
destination: &Path,
) -> Result<(), ExtractError> {
) -> Result<ExtractResult, ExtractError> {
let reader = get_reader(url.clone(), client).await?;
// The `response` is used to stream in the package data
crate::tokio::async_read::extract_tar_bz2(reader, destination).await
Expand All @@ -84,7 +84,7 @@ pub async fn extract_conda(
client: Client,
url: Url,
destination: &Path,
) -> Result<(), ExtractError> {
) -> Result<ExtractResult, ExtractError> {
// The `response` is used to stream in the package data
let reader = get_reader(url.clone(), client).await?;
crate::tokio::async_read::extract_conda(reader, destination).await
@@ -108,7 +108,7 @@ pub async fn extract_conda(
/// .unwrap();
/// # }
/// ```
pub async fn extract(client: Client, url: Url, destination: &Path) -> Result<(), ExtractError> {
pub async fn extract(
client: Client,
url: Url,
destination: &Path,
) -> Result<ExtractResult, ExtractError> {
match ArchiveType::try_from(Path::new(url.path()))
.ok_or(ExtractError::UnsupportedArchiveType)?
{
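
End to end, the async path now surfaces the hashes to the caller. A hedged sketch (not part of this diff) with a made-up URL, assuming the crate's `tokio` and `reqwest` features plus tokio's `rt` and `macros` features:

use rattler_package_streaming::reqwest::tokio::extract;
use reqwest::Client;
use std::path::Path;
use url::Url;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical package URL and destination directory.
    let url = Url::parse("https://example.com/channel/linux-64/some-package-1.0-0.conda")?;
    let result = extract(Client::new(), url, Path::new("/tmp/some-package")).await?;

    // The returned hashes can be checked against the values recorded in repodata.
    println!("sha256: {:x}", result.sha256);
    println!("md5:    {:x}", result.md5);
    Ok(())
}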