From aa46f5a76ffa7742c77a7d6e329c903af579e1a7 Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Fri, 5 May 2023 15:27:31 +0200 Subject: [PATCH] feat: compute hashes while extracting (#176) * feat: compute hashes while extracting * fix: build --- crates/rattler/src/package_cache.rs | 5 +- crates/rattler_digest/src/lib.rs | 59 ++++ crates/rattler_digest/src/tokio.rs | 28 +- crates/rattler_package_streaming/Cargo.toml | 21 +- crates/rattler_package_streaming/src/fs.rs | 8 +- crates/rattler_package_streaming/src/lib.rs | 12 + crates/rattler_package_streaming/src/read.rs | 52 +++- .../src/reqwest/mod.rs | 12 +- .../src/reqwest/tokio.rs | 12 +- .../src/tokio/async_read.rs | 6 +- .../rattler_package_streaming/src/tokio/fs.rs | 14 +- .../tests/extract.rs | 270 ++++++++++++------ 12 files changed, 369 insertions(+), 130 deletions(-) diff --git a/crates/rattler/src/package_cache.rs b/crates/rattler/src/package_cache.rs index ffa03a898..00eac39c3 100644 --- a/crates/rattler/src/package_cache.rs +++ b/crates/rattler/src/package_cache.rs @@ -184,7 +184,9 @@ impl PackageCache { ) -> Result { self.get_or_fetch(pkg, move |destination| async move { tracing::debug!("downloading {} to {}", &url, destination.display()); - rattler_package_streaming::reqwest::tokio::extract(client, url, &destination).await + rattler_package_streaming::reqwest::tokio::extract(client, url, &destination) + .await + .map(|_| ()) }) .await } @@ -268,6 +270,7 @@ mod test { move |destination| async move { rattler_package_streaming::tokio::fs::extract(&tar_archive_path, &destination) .await + .map(|_| ()) }, ) .await diff --git a/crates/rattler_digest/src/lib.rs b/crates/rattler_digest/src/lib.rs index db3355ea8..132b8f49a 100644 --- a/crates/rattler_digest/src/lib.rs +++ b/crates/rattler_digest/src/lib.rs @@ -47,6 +47,7 @@ pub mod serde; pub use digest; use digest::{Digest, Output}; +use std::io::Read; use std::{fs::File, io::Write, path::Path}; pub use md5::Md5; @@ -129,9 +130,49 @@ impl Write for 
HashingWriter { } } +/// A simple object that provides a [`Read`] implementation that also immediately hashes the bytes +/// read from it. Call [`HashingReader::finalize`] to retrieve both the original `impl Read` +/// object as well as the hash. +/// +/// If the `tokio` feature is enabled this object also implements [`::tokio::io::AsyncRead`] which +/// allows you to use it in an async context as well. +pub struct HashingReader { + reader: R, + hasher: D, +} + +impl HashingReader { + /// Constructs a new instance from a reader and a new (empty) hasher. + pub fn new(reader: R) -> Self { + Self { + reader, + hasher: Default::default(), + } + } +} + +impl HashingReader { + /// Consumes this instance and returns the original reader and the hash of all bytes read from + /// this instance. + pub fn finalize(self) -> (R, Output) { + (self.reader, self.hasher.finalize()) + } +} + +impl Read for HashingReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let bytes_read = self.reader.read(buf)?; + self.hasher.update(&buf[..bytes_read]); + Ok(bytes_read) + } +} + #[cfg(test)] mod test { + use super::HashingReader; use rstest::rstest; + use sha2::Sha256; + use std::io::Read; #[rstest] #[case( @@ -153,4 +194,22 @@ mod test { assert_eq!(format!("{hash:x}"), expected_hash) } + + #[rstest] + #[case( + "1234567890", + "c775e7b757ede630cd0aa1113bd102661ab38829ca52a6422ab782862f268646" + )] + #[case( + "Hello, world!", + "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3" + )] + fn test_hashing_reader_sha256(#[case] input: &str, #[case] expected_hash: &str) { + let mut cursor = HashingReader::<_, Sha256>::new(std::io::Cursor::new(input)); + let mut cursor_string = String::new(); + cursor.read_to_string(&mut cursor_string).unwrap(); + assert_eq!(&cursor_string, input); + let (_, hash) = cursor.finalize(); + assert_eq!(format!("{hash:x}"), expected_hash) + } } diff --git a/crates/rattler_digest/src/tokio.rs b/crates/rattler_digest/src/tokio.rs index 
76d0fc82d..530774343 100644 --- a/crates/rattler_digest/src/tokio.rs +++ b/crates/rattler_digest/src/tokio.rs @@ -1,11 +1,12 @@ use super::HashingWriter; +use crate::HashingReader; use digest::Digest; use std::{ io::Error, pin::Pin, task::{Context, Poll}, }; -use tokio::io::AsyncWrite; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; impl AsyncWrite for HashingWriter { fn poll_write( @@ -40,3 +41,28 @@ impl AsyncWrite for HashingWriter { writer.poll_flush(cx) } } + +impl AsyncRead for HashingReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let previously_filled = buf.filled().len(); + + // pin-project the reader + let (reader, hasher) = unsafe { + let this = self.get_unchecked_mut(); + (Pin::new_unchecked(&mut this.reader), &mut this.hasher) + }; + + match reader.poll_read(cx, buf) { + Poll::Ready(Ok(result)) => { + let filled_part = buf.filled(); + hasher.update(&filled_part[previously_filled..]); + Poll::Ready(Ok(result)) + } + other => other, + } + } +} diff --git a/crates/rattler_package_streaming/Cargo.toml b/crates/rattler_package_streaming/Cargo.toml index 3e3704843..732883831 100644 --- a/crates/rattler_package_streaming/Cargo.toml +++ b/crates/rattler_package_streaming/Cargo.toml @@ -11,20 +11,21 @@ license.workspace = true readme.workspace = true [dependencies] -thiserror = "1.0.37" -tar = { version = "0.4.38" } bzip2 = { version = "0.4" } -zip = { version = "0.6.3" } -zstd = "0.12.1" -reqwest = { version = "0.11.13", optional = true, default-features = false } -tokio = { version = "1", optional = true } -tokio-util = { version = "0.7", optional = true } +chrono = "0.4.24" futures-util = { version = "0.3.25", optional = true } -rattler_conda_types = { version = "0.2.0", path = "../rattler_conda_types" } itertools = "0.10.5" +rattler_conda_types = { version = "0.2.0", path = "../rattler_conda_types" } +rattler_digest = { version = "0.2.0", path = "../rattler_digest" } serde_json = "1.0.94" 
+tar = { version = "0.4.38" } +thiserror = "1.0.37" +tokio = { version = "1", optional = true } +tokio-util = { version = "0.7", optional = true } +reqwest = { version = "0.11.13", optional = true, default-features = false } url = "2.3.1" -chrono = "0.4.24" +zip = { version = "0.6.3" } +zstd = "0.12.1" [features] default = ['native-tls'] @@ -37,3 +38,5 @@ reqwest = ["reqwest/blocking"] tempfile = "3.4.0" tokio = { version = "1", features = ["rt", "macros"] } walkdir = "2.3.2" +rstest = "0.17.0" +rstest_reuse = "0.5.0" diff --git a/crates/rattler_package_streaming/src/fs.rs b/crates/rattler_package_streaming/src/fs.rs index f84bfb3d9..b7c5278a4 100644 --- a/crates/rattler_package_streaming/src/fs.rs +++ b/crates/rattler_package_streaming/src/fs.rs @@ -1,6 +1,6 @@ //! Functions to extracting or stream a Conda package from a file on disk. -use crate::ExtractError; +use crate::{ExtractError, ExtractResult}; use rattler_conda_types::package::ArchiveType; use std::fs::File; use std::path::Path; @@ -15,7 +15,7 @@ use std::path::Path; /// Path::new("/tmp")) /// .unwrap(); /// ``` -pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<(), ExtractError> { +pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result { let file = File::open(archive)?; crate::read::extract_tar_bz2(file, destination) } @@ -30,7 +30,7 @@ pub fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<(), Extract /// Path::new("/tmp")) /// .unwrap(); /// ``` -pub fn extract_conda(archive: &Path, destination: &Path) -> Result<(), ExtractError> { +pub fn extract_conda(archive: &Path, destination: &Path) -> Result { let file = File::open(archive)?; crate::read::extract_conda(file, destination) } @@ -46,7 +46,7 @@ pub fn extract_conda(archive: &Path, destination: &Path) -> Result<(), ExtractEr /// Path::new("/tmp")) /// .unwrap(); /// ``` -pub fn extract(archive: &Path, destination: &Path) -> Result<(), ExtractError> { +pub fn extract(archive: &Path, destination: &Path) -> 
Result { match ArchiveType::try_from(archive).ok_or(ExtractError::UnsupportedArchiveType)? { ArchiveType::TarBz2 => extract_tar_bz2(archive, destination), ArchiveType::Conda => extract_conda(archive, destination), diff --git a/crates/rattler_package_streaming/src/lib.rs b/crates/rattler_package_streaming/src/lib.rs index a437c64d6..f851f9276 100644 --- a/crates/rattler_package_streaming/src/lib.rs +++ b/crates/rattler_package_streaming/src/lib.rs @@ -2,6 +2,8 @@ //! This crate provides the ability to extract a Conda package archive or specific parts of it. +use rattler_digest::{Md5Hash, Sha256Hash}; + pub mod read; pub mod seek; @@ -42,3 +44,13 @@ pub enum ExtractError { #[error("the task was cancelled")] Cancelled, } + +/// Result struct returned by extraction functions. +#[derive(Debug)] +pub struct ExtractResult { + /// The SHA256 hash of the extracted archive. + pub sha256: Sha256Hash, + + /// The Md5 hash of the extracted archive. + pub md5: Md5Hash, +} diff --git a/crates/rattler_package_streaming/src/read.rs b/crates/rattler_package_streaming/src/read.rs index 90dcae1d7..cb26a6f9f 100644 --- a/crates/rattler_package_streaming/src/read.rs +++ b/crates/rattler_package_streaming/src/read.rs @@ -1,9 +1,8 @@ //! Functions that enable extracting or streaming a Conda package for objects that implement the //! [`std::io::Read`] trait. -use super::ExtractError; -use std::ffi::OsStr; -use std::{io::Read, path::Path}; +use super::{ExtractError, ExtractResult}; +use std::{ffi::OsStr, io::Read, path::Path}; use zip::read::read_zipfile_from_stream; /// Returns the `.tar.bz2` as a decompressed `tar::Archive`. The `tar::Archive` can be used to @@ -21,19 +20,41 @@ pub(crate) fn stream_tar_zst( } /// Extracts the contents a `.tar.bz2` package archive. 
-pub fn extract_tar_bz2(reader: impl Read, destination: &Path) -> Result<(), ExtractError> { +pub fn extract_tar_bz2( + reader: impl Read, + destination: &Path, +) -> Result { std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?; - stream_tar_bz2(reader).unpack(destination)?; - Ok(()) + + // Wrap the reading in additional readers that will compute the hashes of the file while it's + // being read. + let sha256_reader = rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader); + let mut md5_reader = + rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(sha256_reader); + + // Unpack the archive + stream_tar_bz2(&mut md5_reader).unpack(destination)?; + + // Get the hashes + let (sha256_reader, md5) = md5_reader.finalize(); + let (_, sha256) = sha256_reader.finalize(); + + Ok(ExtractResult { sha256, md5 }) } /// Extracts the contents of a `.conda` package archive. -pub fn extract_conda(mut reader: impl Read, destination: &Path) -> Result<(), ExtractError> { +pub fn extract_conda(reader: impl Read, destination: &Path) -> Result { // Construct the destination path if it doesnt exist yet std::fs::create_dir_all(destination).map_err(ExtractError::CouldNotCreateDestination)?; + // Wrap the reading in additional readers that will compute the hashes of the file while it's + // being read. + let sha256_reader = rattler_digest::HashingReader::<_, rattler_digest::Sha256>::new(reader); + let mut md5_reader = + rattler_digest::HashingReader::<_, rattler_digest::Md5>::new(sha256_reader); + // Iterate over all entries in the zip-file and extract them one-by-one - while let Some(file) = read_zipfile_from_stream(&mut reader)? { + while let Some(file) = read_zipfile_from_stream(&mut md5_reader)? { if file .mangled_name() .file_name() @@ -44,5 +65,18 @@ pub fn extract_conda(mut reader: impl Read, destination: &Path) -> Result<(), Ex } } - Ok(()) + // Read the file to the end to make sure the hash is properly computed. 
+ let mut buf = [0; 1 << 14]; + loop { + let bytes_read = md5_reader.read(&mut buf)?; + if bytes_read == 0 { + break; + } + } + + // Get the hashes + let (sha256_reader, md5) = md5_reader.finalize(); + let (_, sha256) = sha256_reader.finalize(); + + Ok(ExtractResult { sha256, md5 }) } diff --git a/crates/rattler_package_streaming/src/reqwest/mod.rs b/crates/rattler_package_streaming/src/reqwest/mod.rs index cc3a452bd..56972bfd6 100644 --- a/crates/rattler_package_streaming/src/reqwest/mod.rs +++ b/crates/rattler_package_streaming/src/reqwest/mod.rs @@ -3,7 +3,7 @@ #[cfg(feature = "tokio")] pub mod tokio; -use crate::ExtractError; +use crate::{ExtractError, ExtractResult}; use rattler_conda_types::package::ArchiveType; use reqwest::blocking::{Client, Response}; use reqwest::IntoUrl; @@ -25,7 +25,7 @@ pub fn extract_tar_bz2( client: Client, url: impl IntoUrl, destination: &Path, -) -> Result<(), ExtractError> { +) -> Result { // Send the request for the file let response = client .get(url) @@ -53,7 +53,7 @@ pub fn extract_conda( client: Client, url: impl IntoUrl, destination: &Path, -) -> Result<(), ExtractError> { +) -> Result { // Send the request for the file let response = client .get(url) @@ -78,7 +78,11 @@ pub fn extract_conda( /// Path::new("/tmp")) /// .unwrap(); /// ``` -pub fn extract(client: Client, url: impl IntoUrl, destination: &Path) -> Result<(), ExtractError> { +pub fn extract( + client: Client, + url: impl IntoUrl, + destination: &Path, +) -> Result { let url = url .into_url() .map_err(reqwest::Error::from) diff --git a/crates/rattler_package_streaming/src/reqwest/tokio.rs b/crates/rattler_package_streaming/src/reqwest/tokio.rs index 7a45ea347..6a29963f1 100644 --- a/crates/rattler_package_streaming/src/reqwest/tokio.rs +++ b/crates/rattler_package_streaming/src/reqwest/tokio.rs @@ -1,7 +1,7 @@ //! Functionality to stream and extract packages directly from a [`reqwest::Url`] within a [`tokio`] //! async context. 
-use crate::ExtractError; +use crate::{ExtractError, ExtractResult}; use futures_util::stream::TryStreamExt; use rattler_conda_types::package::ArchiveType; use reqwest::{Client, Response}; @@ -57,7 +57,7 @@ pub async fn extract_tar_bz2( client: Client, url: Url, destination: &Path, -) -> Result<(), ExtractError> { +) -> Result { let reader = get_reader(url.clone(), client).await?; // The `response` is used to stream in the package data crate::tokio::async_read::extract_tar_bz2(reader, destination).await @@ -84,7 +84,7 @@ pub async fn extract_conda( client: Client, url: Url, destination: &Path, -) -> Result<(), ExtractError> { +) -> Result { // The `response` is used to stream in the package data let reader = get_reader(url.clone(), client).await?; crate::tokio::async_read::extract_conda(reader, destination).await @@ -108,7 +108,11 @@ pub async fn extract_conda( /// .unwrap(); /// # } /// ``` -pub async fn extract(client: Client, url: Url, destination: &Path) -> Result<(), ExtractError> { +pub async fn extract( + client: Client, + url: Url, + destination: &Path, +) -> Result { match ArchiveType::try_from(Path::new(url.path())) .ok_or(ExtractError::UnsupportedArchiveType)? { diff --git a/crates/rattler_package_streaming/src/tokio/async_read.rs b/crates/rattler_package_streaming/src/tokio/async_read.rs index d1d27fe33..97b1d7833 100644 --- a/crates/rattler_package_streaming/src/tokio/async_read.rs +++ b/crates/rattler_package_streaming/src/tokio/async_read.rs @@ -1,7 +1,7 @@ //! Functions that enable extracting or streaming a Conda package for objects that implement the //! [`tokio::io::AsyncRead`] trait. 
-use crate::ExtractError; +use crate::{ExtractError, ExtractResult}; use std::path::Path; use tokio::io::AsyncRead; use tokio_util::io::SyncIoBridge; @@ -10,7 +10,7 @@ use tokio_util::io::SyncIoBridge; pub async fn extract_tar_bz2( reader: impl AsyncRead + Send + 'static, destination: &Path, -) -> Result<(), ExtractError> { +) -> Result { // Create a async -> sync bridge let reader = SyncIoBridge::new(Box::pin(reader)); @@ -33,7 +33,7 @@ pub async fn extract_tar_bz2( pub async fn extract_conda( reader: impl AsyncRead + Send + 'static, destination: &Path, -) -> Result<(), ExtractError> { +) -> Result { // Create a async -> sync bridge let reader = SyncIoBridge::new(Box::pin(reader)); diff --git a/crates/rattler_package_streaming/src/tokio/fs.rs b/crates/rattler_package_streaming/src/tokio/fs.rs index 521130d67..7a9ba08a3 100644 --- a/crates/rattler_package_streaming/src/tokio/fs.rs +++ b/crates/rattler_package_streaming/src/tokio/fs.rs @@ -1,6 +1,6 @@ //! Functions to extracting or stream a Conda package from a file on disk. 
-use crate::ExtractError; +use crate::{ExtractError, ExtractResult}; use rattler_conda_types::package::ArchiveType; use std::path::Path; @@ -18,7 +18,10 @@ use std::path::Path; /// .unwrap(); /// # } /// ``` -pub async fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<(), ExtractError> { +pub async fn extract_tar_bz2( + archive: &Path, + destination: &Path, +) -> Result { // Spawn a block task to perform the extraction let destination = destination.to_owned(); let archive = archive.to_owned(); @@ -49,7 +52,10 @@ pub async fn extract_tar_bz2(archive: &Path, destination: &Path) -> Result<(), E /// .unwrap(); /// # } /// ``` -pub async fn extract_conda(archive: &Path, destination: &Path) -> Result<(), ExtractError> { +pub async fn extract_conda( + archive: &Path, + destination: &Path, +) -> Result { // Spawn a block task to perform the extraction let destination = destination.to_owned(); let archive = archive.to_owned(); @@ -81,7 +87,7 @@ pub async fn extract_conda(archive: &Path, destination: &Path) -> Result<(), Ext /// .unwrap(); /// # } /// ``` -pub async fn extract(archive: &Path, destination: &Path) -> Result<(), ExtractError> { +pub async fn extract(archive: &Path, destination: &Path) -> Result { match ArchiveType::try_from(archive).ok_or(ExtractError::UnsupportedArchiveType)? 
{ ArchiveType::TarBz2 => extract_tar_bz2(archive, destination).await, ArchiveType::Conda => extract_conda(archive, destination).await, diff --git a/crates/rattler_package_streaming/tests/extract.rs b/crates/rattler_package_streaming/tests/extract.rs index 1d1e5db30..a8a5de57b 100644 --- a/crates/rattler_package_streaming/tests/extract.rs +++ b/crates/rattler_package_streaming/tests/extract.rs @@ -1,147 +1,235 @@ -use rattler_conda_types::package::ArchiveType; use rattler_package_streaming::read::{extract_conda, extract_tar_bz2}; +use rstest::rstest; +use rstest_reuse::{self, *}; use std::fs::File; use std::path::{Path, PathBuf}; -fn find_all_archives() -> impl Iterator { - std::fs::read_dir(Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data")) - .unwrap() - .filter_map(Result::ok) - .map(|d| d.path()) +fn test_data_dir() -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test-data") } -#[test] -fn test_extract_conda() { +#[template] +#[rstest] +#[case::conda( + "conda-22.11.1-py38haa244fe_1.conda", + "a8a44c5ff2b2f423546d49721ba2e3e632233c74a813c944adf8e5742834930e", + "9987c96161034575f5a9c2be848960c5" +)] +#[case::mamba( + "mamba-1.1.0-py39hb3d9227_2.conda", + "c172acdf9cb7655dd224879b30361a657b09bb084b65f151e36a2b51e51a080a", + "d87eb6aecfc0fe58299e6d6cfb252a7f" +)] +#[case::mock( + "mock-2.0.0-py37_1000.conda", + "181ec44eb7b06ebb833eae845bcc466ad96474be1f33ee55cab7ac1b0fdbbfa3", + "23c226430e35a3bd994db6c36b9ac8ae" +)] +#[case::mujoco( + "mujoco-2.3.1-ha3edaa6_0.conda", + "007f27a98a150ac3fbbd5bdd708d35f807ba2e117a194f218b130890d461ce77", + "910c94e2d1234e98196c4a64a82ff07e" +)] +#[case::ruff( + "ruff-0.0.171-py310h298983d_0.conda", + "25c755b97189ee066576b4ae3999d5e7ff4406d236b984742194e63941838dcd", + "1ecacf57f20c0d1e4a04af0c8d4b54a3" +)] +#[case::stir( + "stir-5.0.2-py38h9224444_7.conda", + "352fe747f7f09b09baa4b6561485b3f0d4271f6f798d34dae7116c3c9c6ba896", + "7bb9eb9ddaaf4505777512c5ad2fc108" +)] +fn conda_archives(#[case] input: 
&str, #[case] sha256: &str, #[case] md5: &str) {} + +#[template] +#[rstest] +#[case::conda( + "conda-22.9.0-py38haa244fe_2.tar.bz2", + "3c2c2e8e81bde5fb1ac4b014f51a62411feff004580c708c97a0ec2b7058cdc4", + "36194591e28b9f2c107aa3d952ac4649" +)] +#[case::mamba( + "mamba-1.0.0-py38hecfeebb_2.tar.bz2", + "f44c4bc9c6916ecc0e33137431645b029ade22190c7144eead61446dcbcc6f97", + "dede6252c964db3f3e41c7d30d07f6bf" +)] +#[case::micromamba( + "micromamba-1.1.0-0.tar.bz2", + "5a1e1fe69a301e817cf2795ace03c9e4a42e97cd8984b6edbc8872dad00d5097", + "3774689d66819fb50ff87fccefff6e88" +)] +#[case::mock( + "mock-2.0.0-py37_1000.tar.bz2", + "34c659b0fdc53d28ae721fd5717446fb8abebb1016794bd61e25937853f4c29c", + "0f9cce120a73803a70abb14bd4d4900b" +)] +#[case::pytweening( + "pytweening-1.0.4-pyhd8ed1ab_0.tar.bz2", + "81644bcb60d295f7923770b41daf2d90152ef54b9b094c26513be50fccd62125", + "d5e0fafeaa727f0de1c81bfb6e0e63d8" +)] +#[case::rosbridge( + "ros-noetic-rosbridge-suite-0.11.14-py39h6fdeb60_14.tar.bz2", + "4dd9893f1eee45e1579d1a4f5533ef67a84b5e4b7515de7ed0db1dd47adc6bc8", + "47d2678d67ec7ebd49ade2b9943e597e" +)] +#[case::zlib( + "zlib-1.2.8-vc10_0.tar.bz2", + "ee9172dbe9ebd158e8e68d6d0f7dc2060f0c8230b44d2e9a3595b7cd7336b915", + "8415564d07857a1069c0cd74e7eeb369" +)] +fn tar_bz2_archives(#[case] input: &str, #[case] sha256: &str, #[case] md5: &str) {} + +#[template] +#[rstest] +#[case::ruff( + "https://conda.anaconda.org/conda-forge/win-64/ruff-0.0.205-py39h5b3f8fb_0.conda", + "8affd54b71aabc63ddc3944135a4b31462b09da7d1677a53cd31df50ffe07173", + "bdfa0d81d2337eb713a66119754ad67a" +)] +#[case::python( + "https://conda.anaconda.org/conda-forge/win-64/python-3.11.0-hcf16a7b_0_cpython.tar.bz2", + "20d1f1b5dc620b745c325844545fd5c0cdbfdb2385a0e27ef1507399844c8c6d", + "13ee3577afc291dabd2d9edc59736688" +)] +fn url_archives(#[case] input: &str, #[case] sha256: &str, #[case] md5: &str) {} + +#[apply(conda_archives)] +fn test_extract_conda(#[case] input: &str, #[case] sha256: &str, #[case] md5: &str) 
{ let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")); println!("Target dir: {}", temp_dir.display()); - - for file_path in - find_all_archives().filter(|path| ArchiveType::try_from(path) == Some(ArchiveType::Conda)) - { - println!("Name: {}", file_path.display()); - - let target_dir = temp_dir.join(file_path.file_stem().unwrap()); - extract_conda(File::open(&file_path).unwrap(), &target_dir).unwrap(); - } + let file_path = Path::new(input); + let target_dir = temp_dir.join(file_path.file_stem().unwrap()); + let result = extract_conda( + File::open(&test_data_dir().join(file_path)).unwrap(), + &target_dir, + ) + .unwrap(); + + assert_eq!(&format!("{:x}", result.sha256), sha256); + assert_eq!(&format!("{:x}", result.md5), md5); } -#[test] -fn test_stream_info() { +#[apply(conda_archives)] +fn test_stream_info(#[case] input: &str, #[case] _sha256: &str, #[case] _md5: &str) { let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")); println!("Target dir: {}", temp_dir.display()); - for file_path in - find_all_archives().filter(|path| ArchiveType::try_from(path) == Some(ArchiveType::Conda)) - { - println!("Name: {}", file_path.display()); + let file_path = Path::new(input); - let mut info_stream = - rattler_package_streaming::seek::stream_conda_info(File::open(&file_path).unwrap()) - .unwrap(); + let mut info_stream = rattler_package_streaming::seek::stream_conda_info( + File::open(&test_data_dir().join(file_path)).unwrap(), + ) + .unwrap(); - let target_dir = temp_dir.join(format!( - "{}-info", - file_path.file_stem().unwrap().to_string_lossy().as_ref() - )); + let target_dir = temp_dir.join(format!( + "{}-info", + file_path.file_stem().unwrap().to_string_lossy().as_ref() + )); - info_stream.unpack(target_dir).unwrap(); - } + info_stream.unpack(target_dir).unwrap(); } -#[test] -fn test_extract_tar_bz2() { +#[apply(tar_bz2_archives)] +fn test_extract_tar_bz2(#[case] input: &str, #[case] sha256: &str, #[case] md5: &str) { let temp_dir = 
Path::new(env!("CARGO_TARGET_TMPDIR")); println!("Target dir: {}", temp_dir.display()); - for file_path in - find_all_archives().filter(|path| ArchiveType::try_from(path) == Some(ArchiveType::TarBz2)) - { - println!("Name: {}", file_path.display()); + let file_path = Path::new(input); - let target_dir = temp_dir.join(file_path.file_stem().unwrap()); - extract_tar_bz2(File::open(&file_path).unwrap(), &target_dir).unwrap(); - } + let target_dir = temp_dir.join(file_path.file_stem().unwrap()); + let result = extract_tar_bz2( + File::open(&test_data_dir().join(file_path)).unwrap(), + &target_dir, + ) + .unwrap(); + + assert_eq!(&format!("{:x}", result.sha256), sha256); + assert_eq!(&format!("{:x}", result.md5), md5); } #[cfg(feature = "tokio")] +#[apply(tar_bz2_archives)] #[tokio::test] -async fn test_extract_tar_bz2_async() { +async fn test_extract_tar_bz2_async(#[case] input: &str, #[case] sha256: &str, #[case] md5: &str) { let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")).join("tokio"); println!("Target dir: {}", temp_dir.display()); - for file_path in - find_all_archives().filter(|path| ArchiveType::try_from(path) == Some(ArchiveType::TarBz2)) - { - println!("Name: {}", file_path.display()); - - let target_dir = temp_dir.join(file_path.file_stem().unwrap()); - rattler_package_streaming::tokio::async_read::extract_tar_bz2( - tokio::fs::File::open(&file_path).await.unwrap(), - &target_dir, - ) - .await - .unwrap(); - } + let file_path = Path::new(input); + let target_dir = temp_dir.join(file_path.file_stem().unwrap()); + let result = rattler_package_streaming::tokio::async_read::extract_tar_bz2( + tokio::fs::File::open(&test_data_dir().join(file_path)) + .await + .unwrap(), + &target_dir, + ) + .await + .unwrap(); + + assert_eq!(&format!("{:x}", result.sha256), sha256); + assert_eq!(&format!("{:x}", result.md5), md5); } #[cfg(feature = "tokio")] +#[apply(conda_archives)] #[tokio::test] -async fn test_extract_conda_async() { +async fn 
test_extract_conda_async(#[case] input: &str, #[case] sha256: &str, #[case] md5: &str) { let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")).join("tokio"); println!("Target dir: {}", temp_dir.display()); - for file_path in - find_all_archives().filter(|path| ArchiveType::try_from(path) == Some(ArchiveType::Conda)) - { - println!("Name: {}", file_path.display()); - - let target_dir = temp_dir.join(file_path.file_stem().unwrap()); - rattler_package_streaming::tokio::async_read::extract_conda( - tokio::fs::File::open(&file_path).await.unwrap(), - &target_dir, - ) - .await - .unwrap(); - } + let file_path = Path::new(input); + + let target_dir = temp_dir.join(file_path.file_stem().unwrap()); + let result = rattler_package_streaming::tokio::async_read::extract_conda( + tokio::fs::File::open(&test_data_dir().join(file_path)) + .await + .unwrap(), + &target_dir, + ) + .await + .unwrap(); + + assert_eq!(&format!("{:x}", result.sha256), sha256); + assert_eq!(&format!("{:x}", result.md5), md5); } #[cfg(feature = "reqwest")] -#[test] -fn test_extract_url() { +#[apply(url_archives)] +fn test_extract_url(#[case] url: &str, #[case] sha256: &str, #[case] md5: &str) { let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")); println!("Target dir: {}", temp_dir.display()); - for url in [ - "https://conda.anaconda.org/conda-forge/win-64/ruff-0.0.205-py39h5b3f8fb_0.conda", - "https://conda.anaconda.org/conda-forge/win-64/python-3.11.0-hcf16a7b_0_cpython.tar.bz2", - ] { - let (_, filename) = url.rsplit_once('/').unwrap(); - let name = Path::new(filename); - println!("Name: {}", name.display()); + let (_, filename) = url.rsplit_once('/').unwrap(); + let name = Path::new(filename); + println!("Name: {}", name.display()); - let target_dir = temp_dir.join(name); + let target_dir = temp_dir.join(name); + let result = rattler_package_streaming::reqwest::extract(Default::default(), url, &target_dir).unwrap(); - } + + assert_eq!(&format!("{:x}", result.sha256), sha256); + 
assert_eq!(&format!("{:x}", result.md5), md5); } #[cfg(all(feature = "reqwest", feature = "tokio"))] +#[apply(url_archives)] #[tokio::test] -async fn test_extract_url_async() { +async fn test_extract_url_async(#[case] url: &str, #[case] sha256: &str, #[case] md5: &str) { let temp_dir = Path::new(env!("CARGO_TARGET_TMPDIR")).join("tokio"); println!("Target dir: {}", temp_dir.display()); - for url in [ - "https://conda.anaconda.org/conda-forge/win-64/ruff-0.0.205-py39h5b3f8fb_0.conda", - "https://conda.anaconda.org/conda-forge/win-64/python-3.11.0-hcf16a7b_0_cpython.tar.bz2", - ] { - let (_, filename) = url.rsplit_once('/').unwrap(); - let name = Path::new(filename); - println!("Name: {}", name.display()); + let (_, filename) = url.rsplit_once('/').unwrap(); + let name = Path::new(filename); + println!("Name: {}", name.display()); - let target_dir = temp_dir.join(name); - let url = url::Url::parse(url).unwrap(); + let target_dir = temp_dir.join(name); + let url = url::Url::parse(url).unwrap(); + let result = rattler_package_streaming::reqwest::tokio::extract(Default::default(), url, &target_dir) .await .unwrap(); - } + + assert_eq!(&format!("{:x}", result.sha256), sha256); + assert_eq!(&format!("{:x}", result.md5), md5); }